Skip to content

Commit

Permalink
feat: custom congruence rule for equality in grind (#6510)
Browse files Browse the repository at this point in the history
This PR adds a custom congruence rule for equality in `grind`. The new
rule takes into account that `Eq` is a symmetric relation. In the
future, we will add support for arbitrary symmetric relations. The
current rule is important for propagating disequalities effectively in
`grind`.
  • Loading branch information
leodemoura authored Jan 2, 2025
1 parent e46b5f3 commit 9d62227
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 38 deletions.
3 changes: 3 additions & 0 deletions src/Init/Grind/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ theorem false_of_not_eq_self {a : Prop} (h : (Not a) = a) : False := by
theorem eq_eq_of_eq_true_left {a b : Prop} (h : a = True) : (a = b) = b := by simp [h]
theorem eq_eq_of_eq_true_right {a b : Prop} (h : b = True) : (a = b) = a := by simp [h]

theorem eq_congr {α : Sort u} {a₁ b₁ a₂ b₂ : α} (h₁ : a₁ = a₂) (h₂ : b₁ = b₂) : (a₁ = b₁) = (a₂ = b₂) := by simp [*]
theorem eq_congr' {α : Sort u} {a₁ b₁ a₂ b₂ : α} (h₁ : a₁ = b₂) (h₂ : b₁ = a₂) : (a₁ = b₁) = (a₂ = b₂) := by rw [h₁, h₂, Eq.comm (a := a₂)]

/-! Forall -/

theorem forall_propagator (p : Prop) (q : p → Prop) (q' : Prop) (h₁ : p = True) (h₂ : q (of_eq_true h₁) = q') : (∀ hp : p, q hp) = q' := by
Expand Down
1 change: 1 addition & 0 deletions src/Lean/Meta/Tactic/Grind.lean
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,6 @@ builtin_initialize registerTraceClass `grind.debug.congr
builtin_initialize registerTraceClass `grind.debug.proof
builtin_initialize registerTraceClass `grind.debug.proj
builtin_initialize registerTraceClass `grind.debug.parent
builtin_initialize registerTraceClass `grind.debug.final

end Lean
78 changes: 50 additions & 28 deletions src/Lean/Meta/Tactic/Grind/Proof.lean
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Meta.Sorry -- TODO: remove
import Init.Grind.Lemmas
import Lean.Meta.Tactic.Grind.Types

namespace Lean.Meta.Grind
Expand Down Expand Up @@ -128,6 +128,52 @@ mutual
let r := (← loop lhs rhs).get!
if heq then mkHEqOfEq r else return r

private partial def mkHCongrProof (lhs rhs : Expr) (heq : Bool) : GoalM Expr := do
let f := lhs.getAppFn
let g := rhs.getAppFn
let numArgs := lhs.getAppNumArgs
assert! rhs.getAppNumArgs == numArgs
let thm ← mkHCongrWithArity f numArgs
assert! thm.argKinds.size == numArgs
let rec loop (lhs rhs : Expr) (i : Nat) : GoalM Expr := do
let i := i - 1
if lhs.isApp then
let proof ← loop lhs.appFn! rhs.appFn! i
let a₁ := lhs.appArg!
let a₂ := rhs.appArg!
let k := thm.argKinds[i]!
return mkApp3 proof a₁ a₂ (← mkEqProofCore a₁ a₂ (k matches .heq))
else
return thm.proof
let proof ← loop lhs rhs numArgs
if isSameExpr f g then
mkEqOfHEqIfNeeded proof heq
else
/-
`lhs` is of the form `f a_1 ... a_n`
`rhs` is of the form `g b_1 ... b_n`
`proof : HEq (f a_1 ... a_n) (f b_1 ... b_n)`
We construct a proof for `HEq (f a_1 ... a_n) (g b_1 ... b_n)` using `Eq.ndrec`
-/
let motive ← withLocalDeclD (← mkFreshUserName `x) (← inferType f) fun x => do
mkLambdaFVars #[x] (← mkHEq lhs (mkAppN x rhs.getAppArgs))
let fEq ← mkEqProofCore f g false
let proof ← mkEqNDRec motive proof fEq
mkEqOfHEqIfNeeded proof heq

private partial def mkEqCongrProof (lhs rhs : Expr) (heq : Bool) : GoalM Expr := do
let_expr f@Eq α₁ a₁ b₁ := lhs | unreachable!
let_expr Eq α₂ a₂ b₂ := rhs | unreachable!
let enodes := (← get).enodes
let us := f.constLevels!
if !isSameExpr α₁ α₂ then
mkHCongrProof lhs rhs heq
else if hasSameRoot enodes a₁ a₂ && hasSameRoot enodes b₁ b₂ then
return mkApp7 (mkConst ``Grind.eq_congr us) α₁ a₁ b₁ a₂ b₂ (← mkEqProofCore a₁ a₂ false) (← mkEqProofCore b₁ b₂ false)
else
assert! hasSameRoot enodes a₁ b₂ && hasSameRoot enodes b₁ a₂
return mkApp7 (mkConst ``Grind.eq_congr' us) α₁ a₁ b₁ a₂ b₂ (← mkEqProofCore a₁ b₂ false) (← mkEqProofCore b₁ a₂ false)

/-- Constructs a congruence proof for `lhs` and `rhs`. -/
private partial def mkCongrProof (lhs rhs : Expr) (heq : Bool) : GoalM Expr := do
let f := lhs.getAppFn
Expand All @@ -136,36 +182,12 @@ mutual
assert! rhs.getAppNumArgs == numArgs
if f.isConstOf ``Lean.Grind.nestedProof && g.isConstOf ``Lean.Grind.nestedProof && numArgs == 2 then
mkNestedProofCongr lhs rhs heq
else if f.isConstOf ``Eq && g.isConstOf ``Eq && numArgs == 3 then
mkEqCongrProof lhs rhs heq
else if (← isCongrDefaultProofTarget lhs rhs f g numArgs) then
mkCongrDefaultProof lhs rhs heq
else
let thm ← mkHCongrWithArity f numArgs
assert! thm.argKinds.size == numArgs
let rec loop (lhs rhs : Expr) (i : Nat) : GoalM Expr := do
let i := i - 1
if lhs.isApp then
let proof ← loop lhs.appFn! rhs.appFn! i
let a₁ := lhs.appArg!
let a₂ := rhs.appArg!
let k := thm.argKinds[i]!
return mkApp3 proof a₁ a₂ (← mkEqProofCore a₁ a₂ (k matches .heq))
else
return thm.proof
let proof ← loop lhs rhs numArgs
if isSameExpr f g then
mkEqOfHEqIfNeeded proof heq
else
/-
`lhs` is of the form `f a_1 ... a_n`
`rhs` is of the form `g b_1 ... b_n`
`proof : HEq (f a_1 ... a_n) (f b_1 ... b_n)`
We construct a proof for `HEq (f a_1 ... a_n) (g b_1 ... b_n)` using `Eq.ndrec`
-/
let motive ← withLocalDeclD (← mkFreshUserName `x) (← inferType f) fun x => do
mkLambdaFVars #[x] (← mkHEq lhs (mkAppN x rhs.getAppArgs))
let fEq ← mkEqProofCore f g false
let proof ← mkEqNDRec motive proof fEq
mkEqOfHEqIfNeeded proof heq
mkHCongrProof lhs rhs heq

private partial def realizeEqProof (lhs rhs : Expr) (h : Expr) (flipped : Bool) (heq : Bool) : GoalM Expr := do
let h ← if h == congrPlaceholderProof then
Expand Down
34 changes: 24 additions & 10 deletions src/Lean/Meta/Tactic/Grind/Types.lean
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ private def hashRoot (enodes : ENodeMap) (e : Expr) : UInt64 :=
else
13

private def hasSameRoot (enodes : ENodeMap) (a b : Expr) : Bool := Id.run do
def hasSameRoot (enodes : ENodeMap) (a b : Expr) : Bool := Id.run do
if isSameExpr a b then
return true
else
Expand All @@ -258,24 +258,38 @@ private def hasSameRoot (enodes : ENodeMap) (a b : Expr) : Bool := Id.run do
isSameExpr n1.root n2.root

def congrHash (enodes : ENodeMap) (e : Expr) : UInt64 :=
if e.isAppOfArity ``Lean.Grind.nestedProof 2 then
-- We only hash the proposition
hashRoot enodes (e.getArg! 0)
else
go e 17
match_expr e with
| Grind.nestedProof p _ => hashRoot enodes p
| Eq _ lhs rhs => goEq lhs rhs
| _ => go e 17
where
goEq (lhs rhs : Expr) : UInt64 :=
let h₁ := hashRoot enodes lhs
let h₂ := hashRoot enodes rhs
if h₁ > h₂ then mixHash h₂ h₁ else mixHash h₁ h₂
go (e : Expr) (r : UInt64) : UInt64 :=
match e with
| .app f a => go f (mixHash r (hashRoot enodes a))
| _ => mixHash r (hashRoot enodes e)

/-- Returns `true` if `a` and `b` are congruent modulo the equivalence classes in `enodes`. -/
partial def isCongruent (enodes : ENodeMap) (a b : Expr) : Bool :=
if a.isAppOfArity ``Lean.Grind.nestedProof 2 && b.isAppOfArity ``Lean.Grind.nestedProof 2 then
hasSameRoot enodes (a.getArg! 0) (b.getArg! 0)
else
go a b
match_expr a with
| Grind.nestedProof p₁ _ =>
let_expr Grind.nestedProof p₂ _ := b | false
hasSameRoot enodes p₁ p₂
| Eq α₁ lhs₁ rhs₁ =>
let_expr Eq α₂ lhs₂ rhs₂ := b | false
if isSameExpr α₁ α₂ then
goEq lhs₁ rhs₁ lhs₂ rhs₂
else
go a b
| _ => go a b
where
goEq (lhs₁ rhs₁ lhs₂ rhs₂ : Expr) : Bool :=
(hasSameRoot enodes lhs₁ lhs₂ && hasSameRoot enodes rhs₁ rhs₂)
||
(hasSameRoot enodes lhs₁ rhs₂ && hasSameRoot enodes rhs₁ lhs₂)
go (a b : Expr) : Bool :=
if a.isApp && b.isApp then
hasSameRoot enodes a.appArg! b.appArg! && go a.appFn! b.appFn!
Expand Down
5 changes: 5 additions & 0 deletions tests/lean/run/grind_diseq.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
set_option grind.debug true

example (p q : Prop) (a b c d : Nat) :
a = b → c = d → a ≠ c → (d ≠ b → p) → (d ≠ b → q) → p ∧ q := by
grind (splits:=0)
12 changes: 12 additions & 0 deletions tests/lean/run/grind_ematch2.lean
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,15 @@ example (as bs cs : Array α) (v₁ v₂ : α)
(h₆ : j < as.size)
: cs[j] = as[j] := by
grind

example (as bs cs : Array α) (v₁ v₂ : α)
(i₁ i₂ j : Nat)
(h₁ : i₁ < as.size)
(h₂ : as.set i₁ v₁ = bs)
(h₃ : i₂ < bs.size)
(h₃ : bs.set i₂ v₂ = cs)
(h₄ : j ≠ i₁ ∧ j ≠ i₂)
(h₅ : j < cs.size)
(h₆ : j < as.size)
: cs[j] = as[j] := by
grind

0 comments on commit 9d62227

Please sign in to comment.