From 9d622270a16070d4a7a21f058c1a9ad91dce65e2 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Thu, 2 Jan 2025 23:08:19 +0100 Subject: [PATCH] feat: custom congruence rule for equality in `grind` (#6510) This PR adds a custom congruence rule for equality in `grind`. The new rule takes into account that `Eq` is a symmetric relation. In the future, we will add support for arbitrary symmetric relations. The current rule is important for propagating disequalities effectively in `grind`. --- src/Init/Grind/Lemmas.lean | 3 ++ src/Lean/Meta/Tactic/Grind.lean | 1 + src/Lean/Meta/Tactic/Grind/Proof.lean | 78 +++++++++++++++++---------- src/Lean/Meta/Tactic/Grind/Types.lean | 34 ++++++++---- tests/lean/run/grind_diseq.lean | 5 ++ tests/lean/run/grind_ematch2.lean | 12 +++++ 6 files changed, 95 insertions(+), 38 deletions(-) create mode 100644 tests/lean/run/grind_diseq.lean diff --git a/src/Init/Grind/Lemmas.lean b/src/Init/Grind/Lemmas.lean index 4372e8c91848..d76e243ecf8d 100644 --- a/src/Init/Grind/Lemmas.lean +++ b/src/Init/Grind/Lemmas.lean @@ -51,6 +51,9 @@ theorem false_of_not_eq_self {a : Prop} (h : (Not a) = a) : False := by theorem eq_eq_of_eq_true_left {a b : Prop} (h : a = True) : (a = b) = b := by simp [h] theorem eq_eq_of_eq_true_right {a b : Prop} (h : b = True) : (a = b) = a := by simp [h] +theorem eq_congr {α : Sort u} {a₁ b₁ a₂ b₂ : α} (h₁ : a₁ = a₂) (h₂ : b₁ = b₂) : (a₁ = b₁) = (a₂ = b₂) := by simp [*] +theorem eq_congr' {α : Sort u} {a₁ b₁ a₂ b₂ : α} (h₁ : a₁ = b₂) (h₂ : b₁ = a₂) : (a₁ = b₁) = (a₂ = b₂) := by rw [h₁, h₂, Eq.comm (a := a₂)] + /-! Forall -/ theorem forall_propagator (p : Prop) (q : p → Prop) (q' : Prop) (h₁ : p = True) (h₂ : q (of_eq_true h₁) = q') : (∀ hp : p, q hp) = q' := by diff --git a/src/Lean/Meta/Tactic/Grind.lean b/src/Lean/Meta/Tactic/Grind.lean index 938341c3ec90..11b15243e263 100644 --- a/src/Lean/Meta/Tactic/Grind.lean +++ b/src/Lean/Meta/Tactic/Grind.lean @@ -46,5 +46,6 @@ builtin_initialize registerTraceClass `grind.debug.congr builtin_initialize registerTraceClass `grind.debug.proof builtin_initialize registerTraceClass `grind.debug.proj builtin_initialize registerTraceClass `grind.debug.parent +builtin_initialize registerTraceClass `grind.debug.final end Lean diff --git a/src/Lean/Meta/Tactic/Grind/Proof.lean b/src/Lean/Meta/Tactic/Grind/Proof.lean index 580cf2362dde..f395c3fe89dc 100644 --- a/src/Lean/Meta/Tactic/Grind/Proof.lean +++ b/src/Lean/Meta/Tactic/Grind/Proof.lean @@ -4,7 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE. Authors: Leonardo de Moura -/ prelude -import Lean.Meta.Sorry -- TODO: remove +import Init.Grind.Lemmas import Lean.Meta.Tactic.Grind.Types namespace Lean.Meta.Grind @@ -128,6 +128,52 @@ mutual let r := (← loop lhs rhs).get! if heq then mkHEqOfEq r else return r + private partial def mkHCongrProof (lhs rhs : Expr) (heq : Bool) : GoalM Expr := do + let f := lhs.getAppFn + let g := rhs.getAppFn + let numArgs := lhs.getAppNumArgs + assert! rhs.getAppNumArgs == numArgs + let thm ← mkHCongrWithArity f numArgs + assert! thm.argKinds.size == numArgs + let rec loop (lhs rhs : Expr) (i : Nat) : GoalM Expr := do + let i := i - 1 + if lhs.isApp then + let proof ← loop lhs.appFn! rhs.appFn! i + let a₁ := lhs.appArg! + let a₂ := rhs.appArg! + let k := thm.argKinds[i]! + return mkApp3 proof a₁ a₂ (← mkEqProofCore a₁ a₂ (k matches .heq)) + else + return thm.proof + let proof ← loop lhs rhs numArgs + if isSameExpr f g then + mkEqOfHEqIfNeeded proof heq + else + /- + `lhs` is of the form `f a_1 ... a_n` + `rhs` is of the form `g b_1 ... b_n` + `proof : HEq (f a_1 ... a_n) (f b_1 ... b_n)` + We construct a proof for `HEq (f a_1 ... a_n) (g b_1 ... b_n)` using `Eq.ndrec` + -/ + let motive ← withLocalDeclD (← mkFreshUserName `x) (← inferType f) fun x => do + mkLambdaFVars #[x] (← mkHEq lhs (mkAppN x rhs.getAppArgs)) + let fEq ← mkEqProofCore f g false + let proof ← mkEqNDRec motive proof fEq + mkEqOfHEqIfNeeded proof heq + + private partial def mkEqCongrProof (lhs rhs : Expr) (heq : Bool) : GoalM Expr := do + let_expr f@Eq α₁ a₁ b₁ := lhs | unreachable! + let_expr Eq α₂ a₂ b₂ := rhs | unreachable! + let enodes := (← get).enodes + let us := f.constLevels! + if !isSameExpr α₁ α₂ then + mkHCongrProof lhs rhs heq + else if hasSameRoot enodes a₁ a₂ && hasSameRoot enodes b₁ b₂ then + return mkApp7 (mkConst ``Grind.eq_congr us) α₁ a₁ b₁ a₂ b₂ (← mkEqProofCore a₁ a₂ false) (← mkEqProofCore b₁ b₂ false) + else + assert! hasSameRoot enodes a₁ b₂ && hasSameRoot enodes b₁ a₂ + return mkApp7 (mkConst ``Grind.eq_congr' us) α₁ a₁ b₁ a₂ b₂ (← mkEqProofCore a₁ b₂ false) (← mkEqProofCore b₁ a₂ false) + /-- Constructs a congruence proof for `lhs` and `rhs`. -/ private partial def mkCongrProof (lhs rhs : Expr) (heq : Bool) : GoalM Expr := do let f := lhs.getAppFn @@ -136,36 +182,12 @@ mutual assert! rhs.getAppNumArgs == numArgs if f.isConstOf ``Lean.Grind.nestedProof && g.isConstOf ``Lean.Grind.nestedProof && numArgs == 2 then mkNestedProofCongr lhs rhs heq + else if f.isConstOf ``Eq && g.isConstOf ``Eq && numArgs == 3 then + mkEqCongrProof lhs rhs heq else if (← isCongrDefaultProofTarget lhs rhs f g numArgs) then mkCongrDefaultProof lhs rhs heq else - let thm ← mkHCongrWithArity f numArgs - assert! thm.argKinds.size == numArgs - let rec loop (lhs rhs : Expr) (i : Nat) : GoalM Expr := do - let i := i - 1 - if lhs.isApp then - let proof ← loop lhs.appFn! rhs.appFn! i - let a₁ := lhs.appArg! - let a₂ := rhs.appArg! - let k := thm.argKinds[i]! - return mkApp3 proof a₁ a₂ (← mkEqProofCore a₁ a₂ (k matches .heq)) - else - return thm.proof - let proof ← loop lhs rhs numArgs - if isSameExpr f g then - mkEqOfHEqIfNeeded proof heq - else - /- - `lhs` is of the form `f a_1 ... a_n` - `rhs` is of the form `g b_1 ... b_n` - `proof : HEq (f a_1 ... a_n) (f b_1 ... b_n)` - We construct a proof for `HEq (f a_1 ... a_n) (g b_1 ... b_n)` using `Eq.ndrec` - -/ - let motive ← withLocalDeclD (← mkFreshUserName `x) (← inferType f) fun x => do - mkLambdaFVars #[x] (← mkHEq lhs (mkAppN x rhs.getAppArgs)) - let fEq ← mkEqProofCore f g false - let proof ← mkEqNDRec motive proof fEq - mkEqOfHEqIfNeeded proof heq + mkHCongrProof lhs rhs heq private partial def realizeEqProof (lhs rhs : Expr) (h : Expr) (flipped : Bool) (heq : Bool) : GoalM Expr := do let h ← if h == congrPlaceholderProof then diff --git a/src/Lean/Meta/Tactic/Grind/Types.lean b/src/Lean/Meta/Tactic/Grind/Types.lean index 1065c7b4d344..0a3329915afd 100644 --- a/src/Lean/Meta/Tactic/Grind/Types.lean +++ b/src/Lean/Meta/Tactic/Grind/Types.lean @@ -249,7 +249,7 @@ private def hashRoot (enodes : ENodeMap) (e : Expr) : UInt64 := else 13 -private def hasSameRoot (enodes : ENodeMap) (a b : Expr) : Bool := Id.run do +def hasSameRoot (enodes : ENodeMap) (a b : Expr) : Bool := Id.run do if isSameExpr a b then return true else @@ -258,12 +258,15 @@ private def hasSameRoot (enodes : ENodeMap) (a b : Expr) : Bool := Id.run do isSameExpr n1.root n2.root def congrHash (enodes : ENodeMap) (e : Expr) : UInt64 := - if e.isAppOfArity ``Lean.Grind.nestedProof 2 then - -- We only hash the proposition - hashRoot enodes (e.getArg! 0) - else - go e 17 + match_expr e with + | Grind.nestedProof p _ => hashRoot enodes p + | Eq _ lhs rhs => goEq lhs rhs + | _ => go e 17 where + goEq (lhs rhs : Expr) : UInt64 := + let h₁ := hashRoot enodes lhs + let h₂ := hashRoot enodes rhs + if h₁ > h₂ then mixHash h₂ h₁ else mixHash h₁ h₂ go (e : Expr) (r : UInt64) : UInt64 := match e with | .app f a => go f (mixHash r (hashRoot enodes a)) @@ -271,11 +274,22 @@ where /-- Returns `true` if `a` and `b` are congruent modulo the equivalence classes in `enodes`. -/ partial def isCongruent (enodes : ENodeMap) (a b : Expr) : Bool := - if a.isAppOfArity ``Lean.Grind.nestedProof 2 && b.isAppOfArity ``Lean.Grind.nestedProof 2 then - hasSameRoot enodes (a.getArg! 0) (b.getArg! 0) - else - go a b + match_expr a with + | Grind.nestedProof p₁ _ => + let_expr Grind.nestedProof p₂ _ := b | false + hasSameRoot enodes p₁ p₂ + | Eq α₁ lhs₁ rhs₁ => + let_expr Eq α₂ lhs₂ rhs₂ := b | false + if isSameExpr α₁ α₂ then + goEq lhs₁ rhs₁ lhs₂ rhs₂ + else + go a b + | _ => go a b where + goEq (lhs₁ rhs₁ lhs₂ rhs₂ : Expr) : Bool := + (hasSameRoot enodes lhs₁ lhs₂ && hasSameRoot enodes rhs₁ rhs₂) + || + (hasSameRoot enodes lhs₁ rhs₂ && hasSameRoot enodes rhs₁ lhs₂) go (a b : Expr) : Bool := if a.isApp && b.isApp then hasSameRoot enodes a.appArg! b.appArg! && go a.appFn! b.appFn! diff --git a/tests/lean/run/grind_diseq.lean b/tests/lean/run/grind_diseq.lean new file mode 100644 index 000000000000..724272154ede --- /dev/null +++ b/tests/lean/run/grind_diseq.lean @@ -0,0 +1,5 @@ +set_option grind.debug true + +example (p q : Prop) (a b c d : Nat) : + a = b → c = d → a ≠ c → (d ≠ b → p) → (d ≠ b → q) → p ∧ q := by + grind (splits:=0) diff --git a/tests/lean/run/grind_ematch2.lean b/tests/lean/run/grind_ematch2.lean index 2d4f1c850b98..06b6256ba1de 100644 --- a/tests/lean/run/grind_ematch2.lean +++ b/tests/lean/run/grind_ematch2.lean @@ -29,3 +29,15 @@ example (as bs cs : Array α) (v₁ v₂ : α) (h₆ : j < as.size) : cs[j] = as[j] := by grind + +example (as bs cs : Array α) (v₁ v₂ : α) + (i₁ i₂ j : Nat) + (h₁ : i₁ < as.size) + (h₂ : as.set i₁ v₁ = bs) + (h₃ : i₂ < bs.size) + (h₃ : bs.set i₂ v₂ = cs) + (h₄ : j ≠ i₁ ∧ j ≠ i₂) + (h₅ : j < cs.size) + (h₆ : j < as.size) + : cs[j] = as[j] := by + grind