From 2421f7f79941853da14346a234aec6df70cf36a1 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sun, 12 Jan 2025 20:38:14 -0800 Subject: [PATCH] feat: exhaustive offset constraint propagation in the `grind` tactic (#6618) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR implements exhaustive offset constraint propagation in the `grind` tactic. This enhancement minimizes the number of case splits performed by `grind`. For instance, it can solve the following example without performing any case splits: ```lean example (p q r s : Prop) (a b : Nat) : (a + 1 ≤ c ↔ p) → (a + 2 ≤ c ↔ s) → (a ≤ c ↔ q) → (a ≤ c + 4 ↔ r) → a ≤ b → b + 2 ≤ c → p ∧ q ∧ r ∧ s := by grind (splits := 0) ``` TODO: support for equational offset constraints. --- src/Init/Grind/Offset.lean | 35 ++++- src/Lean/Meta/Tactic/Grind.lean | 1 + .../Meta/Tactic/Grind/Arith/Internalize.lean | 19 --- src/Lean/Meta/Tactic/Grind/Arith/Main.lean | 4 +- src/Lean/Meta/Tactic/Grind/Arith/Offset.lean | 136 ++++++++++++++---- .../Meta/Tactic/Grind/Arith/ProofUtil.lean | 104 ++++++++++++-- src/Lean/Meta/Tactic/Grind/Arith/Types.lean | 20 ++- src/Lean/Meta/Tactic/Grind/Arith/Util.lean | 30 ++-- tests/lean/run/grind_offset_cnstr.lean | 77 ++++++++++ tests/lean/run/grind_t1.lean | 6 + 10 files changed, 349 insertions(+), 83 deletions(-) diff --git a/src/Init/Grind/Offset.lean b/src/Init/Grind/Offset.lean index 08d17714aefa..1326275b04ac 100644 --- a/src/Init/Grind/Offset.lean +++ b/src/Init/Grind/Offset.lean @@ -8,8 +8,10 @@ import Init.Core import Init.Omega namespace Lean.Grind -def isLt (x y : Nat) : Bool := x < y +abbrev isLt (x y : Nat) : Bool := x < y +abbrev isLE (x y : Nat) : Bool := x ≤ y +/-! Theorems for transitivity. -/ theorem Nat.le_ro (u w v k : Nat) : u ≤ w → w ≤ v + k → u ≤ v + k := by omega theorem Nat.le_lo (u w v k : Nat) : u ≤ w → w + k ≤ v → u + k ≤ v := by @@ -31,6 +33,7 @@ theorem Nat.ro_lo_2 (u w v k₁ k₂ : Nat) : isLt k₁ k₂ = true → u ≤ w theorem Nat.ro_ro (u w v k₁ k₂ : Nat) : u ≤ w + k₁ → w ≤ v + k₂ → u ≤ v + (k₁ + k₂) := by omega +/-! Theorems for negating constraints. -/ theorem Nat.of_le_eq_false (u v : Nat) : ((u ≤ v) = False) → v + 1 ≤ u := by simp; omega theorem Nat.of_lo_eq_false_1 (u v : Nat) : ((u + 1 ≤ v) = False) → v ≤ u := by @@ -40,6 +43,7 @@ theorem Nat.of_lo_eq_false (u v k : Nat) : ((u + k ≤ v) = False) → v ≤ u + theorem Nat.of_ro_eq_false (u v k : Nat) : ((u ≤ v + k) = False) → v + (k+1) ≤ u := by simp; omega +/-! Theorems for closing a goal. -/ theorem Nat.unsat_le_lo (u v k : Nat) : isLt 0 k = true → u ≤ v → v + k ≤ u → False := by simp [isLt]; omega theorem Nat.unsat_lo_lo (u v k₁ k₂ : Nat) : isLt 0 (k₁+k₂) = true → u + k₁ ≤ v → v + k₂ ≤ u → False := by @@ -47,4 +51,33 @@ theorem Nat.unsat_lo_lo (u v k₁ k₂ : Nat) : isLt 0 (k₁+k₂) = true → u theorem Nat.unsat_lo_ro (u v k₁ k₂ : Nat) : isLt k₂ k₁ = true → u + k₁ ≤ v → v ≤ u + k₂ → False := by simp [isLt]; omega +/-! Theorems for propagating constraints to `True` -/ +theorem Nat.lo_eq_true_of_lo (u v k₁ k₂ : Nat) : isLE k₂ k₁ = true → u + k₁ ≤ v → (u + k₂ ≤ v) = True := + by simp [isLt]; omega +theorem Nat.le_eq_true_of_lo (u v k : Nat) : u + k ≤ v → (u ≤ v) = True := + by simp; omega +theorem Nat.le_eq_true_of_le (u v : Nat) : u ≤ v → (u ≤ v) = True := + by simp +theorem Nat.ro_eq_true_of_lo (u v k₁ k₂ : Nat) : u + k₁ ≤ v → (u ≤ v + k₂) = True := + by simp; omega +theorem Nat.ro_eq_true_of_le (u v k : Nat) : u ≤ v → (u ≤ v + k) = True := + by simp; omega +theorem Nat.ro_eq_true_of_ro (u v k₁ k₂ : Nat) : isLE k₁ k₂ = true → u ≤ v + k₁ → (u ≤ v + k₂) = True := + by simp [isLE]; omega + +/-! +Theorems for propagating constraints to `False`. +They are variants of the theorems for closing a goal. +-/ +theorem Nat.lo_eq_false_of_le (u v k : Nat) : isLt 0 k = true → u ≤ v → (v + k ≤ u) = False := by + simp [isLt]; omega +theorem Nat.le_eq_false_of_lo (u v k : Nat) : isLt 0 k = true → u + k ≤ v → (v ≤ u) = False := by + simp [isLt]; omega +theorem Nat.lo_eq_false_of_lo (u v k₁ k₂ : Nat) : isLt 0 (k₁+k₂) = true → u + k₁ ≤ v → (v + k₂ ≤ u) = False := by + simp [isLt]; omega +theorem Nat.ro_eq_false_of_lo (u v k₁ k₂ : Nat) : isLt k₂ k₁ = true → u + k₁ ≤ v → (v ≤ u + k₂) = False := by + simp [isLt]; omega +theorem Nat.lo_eq_false_of_ro (u v k₁ k₂ : Nat) : isLt k₁ k₂ = true → u ≤ v + k₁ → (v + k₂ ≤ u) = False := by + simp [isLt]; omega + end Lean.Grind diff --git a/src/Lean/Meta/Tactic/Grind.lean b/src/Lean/Meta/Tactic/Grind.lean index 4982d12bf8c2..fe32d443e1eb 100644 --- a/src/Lean/Meta/Tactic/Grind.lean +++ b/src/Lean/Meta/Tactic/Grind.lean @@ -47,6 +47,7 @@ builtin_initialize registerTraceClass `grind.offset builtin_initialize registerTraceClass `grind.offset.dist builtin_initialize registerTraceClass `grind.offset.internalize builtin_initialize registerTraceClass `grind.offset.internalize.term (inherited := true) +builtin_initialize registerTraceClass `grind.offset.propagate /-! Trace options for `grind` developers -/ builtin_initialize registerTraceClass `grind.debug diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Internalize.lean b/src/Lean/Meta/Tactic/Grind/Arith/Internalize.lean index 96fac2547f80..a0bfc66d9de5 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Internalize.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Internalize.lean @@ -8,25 +8,6 @@ import Lean.Meta.Tactic.Grind.Arith.Offset namespace Lean.Meta.Grind.Arith -namespace Offset - -def internalizeTerm (_e : Expr) (_a : Expr) (_k : Nat) : GoalM Unit := do - -- TODO - return () - -def internalizeCnstr (e : Expr) : GoalM Unit := do - let some c := isNatOffsetCnstr? e | return () - let c := { c with - a := (← mkNode c.a) - b := (← mkNode c.b) - } - trace[grind.offset.internalize] "{e} ↦ {c}" - modify' fun s => { s with - cnstrs := s.cnstrs.insert { expr := e } c - } - -end Offset - def internalize (e : Expr) : GoalM Unit := do Offset.internalizeCnstr e diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Main.lean b/src/Lean/Meta/Tactic/Grind/Arith/Main.lean index d46388c3ae0b..d8f90291b1a7 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Main.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Main.lean @@ -14,12 +14,12 @@ def isCnstr? (e : Expr) : GoalM (Option (Cnstr NodeId)) := return (← get).arith.offset.cnstrs.find? { expr := e } def assertTrue (c : Cnstr NodeId) (p : Expr) : GoalM Unit := do - addEdge c.a c.b c.k (← mkOfEqTrue p) + addEdge c.u c.v c.k (← mkOfEqTrue p) def assertFalse (c : Cnstr NodeId) (p : Expr) : GoalM Unit := do let p := mkOfNegEqFalse (← get').nodes c p let c := c.neg - addEdge c.a c.b c.k p + addEdge c.u c.v c.k p end Offset diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Offset.lean b/src/Lean/Meta/Tactic/Grind/Arith/Offset.lean index 6a87df5c367b..31b07311e80b 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Offset.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Offset.lean @@ -50,13 +50,20 @@ def mkNode (expr : Expr) : GoalM NodeId := do } return nodeId +private def getExpr (u : NodeId) : GoalM Expr := do + return (← get').nodes[u]! + private def getDist? (u v : NodeId) : GoalM (Option Int) := do return (← get').targets[u]!.find? v private def getProof? (u v : NodeId) : GoalM (Option ProofInfo) := do return (← get').proofs[u]!.find? v -partial def extractProof (u v : NodeId) : GoalM Expr := do +/-- +Returns a proof for `u + k ≤ v` (or `u ≤ v + k`) where `k` is the +shortest path between `u` and `v`. +-/ +private partial def mkProofForPath (u v : NodeId) : GoalM Expr := do go (← getProof? u v).get! where go (p : ProofInfo) : GoalM Expr := do @@ -66,29 +73,21 @@ where let p' := (← getProof? u p.w).get! go (mkTrans (← get').nodes p' p v) +/-- +Given a new edge edge `u --(kuv)--> v` justified by proof `huv` s.t. +it creates a negative cycle with the existing path `v --{kvu}-->* u`, i.e., `kuv + kvu < 0`, +this function closes the current goal by constructing a proof of `False`. +-/ private def setUnsat (u v : NodeId) (kuv : Int) (huv : Expr) (kvu : Int) : GoalM Unit := do assert! kuv + kvu < 0 - let hvu ← extractProof v u - let u := (← get').nodes[u]! - let v := (← get').nodes[v]! - if kuv == 0 then - assert! kvu < 0 - closeGoal (mkApp6 (mkConst ``Grind.Nat.unsat_le_lo) u v (toExpr (-kvu).toNat) rfl_true huv hvu) - else if kvu == 0 then - assert! kuv < 0 - closeGoal (mkApp6 (mkConst ``Grind.Nat.unsat_le_lo) v u (toExpr (-kuv).toNat) rfl_true hvu huv) - else if kuv < 0 then - if kvu > 0 then - closeGoal (mkApp7 (mkConst ``Grind.Nat.unsat_lo_ro) u v (toExpr (-kuv).toNat) (toExpr kvu.toNat) rfl_true huv hvu) - else - assert! kvu < 0 - closeGoal (mkApp7 (mkConst ``Grind.Nat.unsat_lo_lo) u v (toExpr (-kuv).toNat) (toExpr (-kvu).toNat) rfl_true huv hvu) - else - assert! kuv > 0 && kvu < 0 - closeGoal (mkApp7 (mkConst ``Grind.Nat.unsat_lo_ro) v u (toExpr (-kvu).toNat) (toExpr kuv.toNat) rfl_true hvu huv) + let hvu ← mkProofForPath v u + let u ← getExpr u + let v ← getExpr v + closeGoal (mkUnsatProof u v kuv huv kvu hvu) +/-- Sets the new shortest distance `k` between nodes `u` and `v`. -/ private def setDist (u v : NodeId) (k : Int) : GoalM Unit := do - trace[grind.offset.dist] "{({ a := u, b := v, k : Cnstr NodeId})}" + trace[grind.offset.dist] "{({ u, v, k : Cnstr NodeId})}" modify' fun s => { s with targets := s.targets.modify u fun es => es.insert v k sources := s.sources.modify v fun es => es.insert u k @@ -107,17 +106,80 @@ private def forEachSourceOf (u : NodeId) (f : NodeId → Int → GoalM Unit) : G private def forEachTargetOf (u : NodeId) (f : NodeId → Int → GoalM Unit) : GoalM Unit := do (← get').targets[u]!.forM f +/-- Returns `true` if `k` is smaller than the shortest distance between `u` and `v` -/ private def isShorter (u v : NodeId) (k : Int) : GoalM Bool := do if let some k' ← getDist? u v then return k < k' else return true +/-- +Tries to assign `e` to `True`, which is represented by constraint `c` (from `u` to `v`), using the +path `u --(k)--> v`. +-/ +private def propagateTrue (u v : NodeId) (k : Int) (c : Cnstr NodeId) (e : Expr) : GoalM Bool := do + if k ≤ c.k then + trace[grind.offset.propagate] "{{ u, v, k : Cnstr NodeId}} ==> {e} = True" + let kuv ← mkProofForPath u v + let u ← getExpr u + let v ← getExpr v + pushEqTrue e <| mkPropagateEqTrueProof u v k kuv c.k + return true + return false + +example (x y : Nat) : x + 2 ≤ y → ¬ (y ≤ x + 1) := by omega + +/-- +Tries to assign `e` to `False`, which is represented by constraint `c` (from `v` to `u`), using the +path `u --(k)--> v`. +-/ +private def propagateFalse (u v : NodeId) (k : Int) (c : Cnstr NodeId) (e : Expr) : GoalM Bool := do + if k + c.k < 0 then + trace[grind.offset.propagate] "{{ u, v, k : Cnstr NodeId}} ==> {e} = False" + let kuv ← mkProofForPath u v + let u ← getExpr u + let v ← getExpr v + pushEqFalse e <| mkPropagateEqFalseProof u v k kuv c.k + return false + +/-- +Auxiliary function for implementing `propagateAll`. +Traverses the constraints `c` (representing an expression `e`) s.t. +`c.u = u` and `c.v = v`, it removes `c` from the list of constraints +associated with `(u, v)` IF +- `e` is already assigned, or +- `f c e` returns true +-/ +@[inline] +private def updateCnstrsOf (u v : NodeId) (f : Cnstr NodeId → Expr → GoalM Bool) : GoalM Unit := do + if let some cs := (← get').cnstrsOf.find? (u, v) then + let cs' ← cs.filterM fun (c, e) => do + if (← isEqTrue e <||> isEqFalse e) then + return false -- constraint was already assigned + else + return !(← f c e) + modify' fun s => { s with cnstrsOf := s.cnstrsOf.insert (u, v) cs' } + +/-- Performs constraint propagation. -/ +private def propagateAll (u v : NodeId) (k : Int) : GoalM Unit := do + updateCnstrsOf u v fun c e => return !(← propagateTrue u v k c e) + updateCnstrsOf v u fun c e => return !(← propagateFalse u v k c e) + +/-- +If `isShorter u v k`, updates the shortest distance between `u` and `v`. +`w` is the penultimate node in the path from `u` to `v`. +-/ private def updateIfShorter (u v : NodeId) (k : Int) (w : NodeId) : GoalM Unit := do if (← isShorter u v k) then setDist u v k setProof u v (← getProof? w v).get! + propagateAll u v k +/-- +Adds an edge `u --(k) --> v` justified by the proof term `p`, and then +if no negative cycle was created, updates the shortest distance of affected +node pairs. +-/ def addEdge (u : NodeId) (v : NodeId) (k : Int) (p : Expr) : GoalM Unit := do if (← isInconsistent) then return () if let some k' ← getDist? v u then @@ -127,6 +189,7 @@ def addEdge (u : NodeId) (v : NodeId) (k : Int) (p : Expr) : GoalM Unit := do if (← isShorter u v k) then setDist u v k setProof u v { w := u, k, proof := p } + propagateAll u v k update where update : GoalM Unit := do @@ -140,6 +203,25 @@ where /- Check whether new path: `i -(k₁)-> u -(k)-> v -(k₂) -> j` is shorter -/ updateIfShorter i j (k₁+k+k₂) v +def internalizeCnstr (e : Expr) : GoalM Unit := do + let some c := isNatOffsetCnstr? e | return () + let u ← mkNode c.u + let v ← mkNode c.v + let c := { c with u, v } + if let some k ← getDist? u v then + if (← propagateTrue u v k c e) then + return () + if let some k ← getDist? v u then + if (← propagateFalse v u k c e) then + return () + trace[grind.offset.internalize] "{e} ↦ {c}" + modify' fun s => { s with + cnstrs := s.cnstrs.insert { expr := e } c + cnstrsOf := + let cs := if let some cs := s.cnstrsOf.find? (u, v) then (c, e) :: cs else [(c, e)] + s.cnstrsOf.insert (u, v) cs + } + def traceDists : GoalM Unit := do let s ← get' for u in [:s.targets.size], es in s.targets.toArray do @@ -147,23 +229,23 @@ def traceDists : GoalM Unit := do trace[grind.offset.dist] "#{u} -({k})-> #{v}" def Cnstr.toExpr (c : Cnstr NodeId) : GoalM Expr := do - let a := (← get').nodes[c.a]! - let b := (← get').nodes[c.b]! + let u := (← get').nodes[c.u]! + let v := (← get').nodes[c.v]! let mk := if c.le then mkNatLE else mkNatEq if c.k == 0 then - return mk a b + return mk u v else if c.k < 0 then - return mk (mkNatAdd a (Lean.toExpr ((-c.k).toNat))) b + return mk (mkNatAdd u (Lean.toExpr ((-c.k).toNat))) v else - return mk a (mkNatAdd b (Lean.toExpr c.k.toNat)) + return mk u (mkNatAdd v (Lean.toExpr c.k.toNat)) def checkInvariants : GoalM Unit := do let s ← get' for u in [:s.targets.size], es in s.targets.toArray do for (v, k) in es do - let c : Cnstr NodeId := { a := u, b := v, k } + let c : Cnstr NodeId := { u, v, k } trace[grind.debug.offset] "{c}" - let p ← extractProof u v + let p ← mkProofForPath u v trace[grind.debug.offset.proof] "{p} : {← inferType p}" check p unless (← withDefault <| isDefEq (← inferType p) (← Cnstr.toExpr c)) do diff --git a/src/Lean/Meta/Tactic/Grind/Arith/ProofUtil.lean b/src/Lean/Meta/Tactic/Grind/Arith/ProofUtil.lean index 602814cbdfd6..d45b266df2b8 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/ProofUtil.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/ProofUtil.lean @@ -19,6 +19,10 @@ namespace Offset /-- Returns a proof for `true = true` -/ def rfl_true : Expr := mkConst ``Grind.rfl_true +private def toExprN (n : Int) := + assert! n >= 0 + toExpr n.toNat + open Lean.Grind in /-- Assume `pi₁` is `{ w := u, k := k₁, proof := p₁ }` and `pi₂` is `{ w := w, k := k₂, proof := p₂ }` @@ -37,53 +41,127 @@ def mkTrans (nodes : PArray Expr) (pi₁ : ProofInfo) (pi₂ : ProofInfo) (v : N mkApp5 (mkConst ``Nat.le_trans) u w v p₁ p₂ else if k₂ > 0 then -- u ≤ v, w ≤ v + k₂ - mkApp6 (mkConst ``Nat.le_ro) u w v (toExpr k₂.toNat) p₁ p₂ + mkApp6 (mkConst ``Nat.le_ro) u w v (toExprN k₂) p₁ p₂ else let k₂ := - k₂ -- u ≤ w, w + k₂ ≤ v - mkApp6 (mkConst ``Nat.le_lo) u w v (toExpr k₂.toNat) p₁ p₂ + mkApp6 (mkConst ``Nat.le_lo) u w v (toExprN k₂) p₁ p₂ else if k₁ < 0 then let k₁ := -k₁ if k₂ == 0 then - mkApp6 (mkConst ``Nat.lo_le) u w v (toExpr k₁.toNat) p₁ p₂ + mkApp6 (mkConst ``Nat.lo_le) u w v (toExprN k₁) p₁ p₂ else if k₂ < 0 then let k₂ := -k₂ - mkApp7 (mkConst ``Nat.lo_lo) u w v (toExpr k₁.toNat) (toExpr k₂.toNat) p₁ p₂ + mkApp7 (mkConst ``Nat.lo_lo) u w v (toExprN k₁) (toExprN k₂) p₁ p₂ else - let ke₁ := toExpr k₁.toNat - let ke₂ := toExpr k₂.toNat + let ke₁ := toExprN k₁ + let ke₂ := toExprN k₂ if k₁ > k₂ then mkApp8 (mkConst ``Nat.lo_ro_1) u w v ke₁ ke₂ rfl_true p₁ p₂ else mkApp7 (mkConst ``Nat.lo_ro_2) u w v ke₁ ke₂ p₁ p₂ else - let ke₁ := toExpr k₁.toNat + let ke₁ := toExprN k₁ if k₂ == 0 then mkApp6 (mkConst ``Nat.ro_le) u w v ke₁ p₁ p₂ else if k₂ < 0 then let k₂ := -k₂ - let ke₂ := toExpr k₂.toNat + let ke₂ := toExprN k₂ if k₂ > k₁ then mkApp8 (mkConst ``Nat.ro_lo_2) u w v ke₁ ke₂ rfl_true p₁ p₂ else mkApp7 (mkConst ``Nat.ro_lo_1) u w v ke₁ ke₂ p₁ p₂ else - let ke₂ := toExpr k₂.toNat + let ke₂ := toExprN k₂ mkApp7 (mkConst ``Nat.ro_ro) u w v ke₁ ke₂ p₁ p₂ { w := pi₁.w, k := k₁+k₂, proof := p } open Lean.Grind in def mkOfNegEqFalse (nodes : PArray Expr) (c : Cnstr NodeId) (h : Expr) : Expr := - let u := nodes[c.a]! - let v := nodes[c.b]! + let u := nodes[c.u]! + let v := nodes[c.v]! if c.k == 0 then mkApp3 (mkConst ``Nat.of_le_eq_false) u v h else if c.k == -1 && c.le then mkApp3 (mkConst ``Nat.of_lo_eq_false_1) u v h else if c.k < 0 then - mkApp4 (mkConst ``Nat.of_lo_eq_false) u v (toExpr (-c.k).toNat) h + mkApp4 (mkConst ``Nat.of_lo_eq_false) u v (toExprN (-c.k)) h + else + mkApp4 (mkConst ``Nat.of_ro_eq_false) u v (toExprN c.k) h + +/-- +Returns a proof of `False` using a negative cycle composed of +- `u --(kuv)--> v` with proof `huv` +- `v --(kvu)--> u` with proof `hvu` +-/ +def mkUnsatProof (u v : Expr) (kuv : Int) (huv : Expr) (kvu : Int) (hvu : Expr) : Expr := + if kuv == 0 then + assert! kvu < 0 + mkApp6 (mkConst ``Grind.Nat.unsat_le_lo) u v (toExprN (-kvu)) rfl_true huv hvu + else if kvu == 0 then + mkApp6 (mkConst ``Grind.Nat.unsat_le_lo) v u (toExprN (-kuv)) rfl_true hvu huv + else if kuv < 0 then + if kvu > 0 then + mkApp7 (mkConst ``Grind.Nat.unsat_lo_ro) u v (toExprN (-kuv)) (toExprN kvu) rfl_true huv hvu + else + assert! kvu < 0 + mkApp7 (mkConst ``Grind.Nat.unsat_lo_lo) u v (toExprN (-kuv)) (toExprN (-kvu)) rfl_true huv hvu + else + assert! kuv > 0 && kvu < 0 + mkApp7 (mkConst ``Grind.Nat.unsat_lo_ro) v u (toExprN (-kvu)) (toExprN kuv) rfl_true hvu huv + +/-- +Given a path `u --(kuv)--> v` justified by proof `huv`, +construct a proof of `e = True` where `e` is a term corresponding to the edgen `u --(k') --> v` +s.t. `k ≤ k'` +-/ +def mkPropagateEqTrueProof (u v : Expr) (k : Int) (huv : Expr) (k' : Int) : Expr := + if k == 0 then + if k' == 0 then + mkApp3 (mkConst ``Grind.Nat.le_eq_true_of_le) u v huv + else + assert! k' > 0 + mkApp4 (mkConst ``Grind.Nat.ro_eq_true_of_le) u v (toExprN k') huv + else if k < 0 then + let k := -k + if k' == 0 then + mkApp4 (mkConst ``Grind.Nat.le_eq_true_of_lo) u v (toExprN k) huv + else if k' < 0 then + let k' := -k' + mkApp6 (mkConst ``Grind.Nat.lo_eq_true_of_lo) u v (toExprN k) (toExprN k') rfl_true huv + else + assert! k' > 0 + mkApp5 (mkConst ``Grind.Nat.ro_eq_true_of_lo) u v (toExprN k) (toExprN k') huv + else + assert! k > 0 + assert! k' > 0 + mkApp6 (mkConst ``Grind.Nat.ro_eq_true_of_ro) u v (toExprN k) (toExprN k') rfl_true huv + +/-- +Given a path `u --(kuv)--> v` justified by proof `huv`, +construct a proof of `e = False` where `e` is a term corresponding to the edgen `v --(k') --> u` +s.t. `k+k' < 0` +-/ +def mkPropagateEqFalseProof (u v : Expr) (k : Int) (huv : Expr) (k' : Int) : Expr := + if k == 0 then + assert! k' < 0 + let k' := -k' + mkApp5 (mkConst ``Grind.Nat.lo_eq_false_of_le) u v (toExprN k') rfl_true huv + else if k < 0 then + let k := -k + if k' == 0 then + mkApp5 (mkConst ``Grind.Nat.le_eq_false_of_lo) u v (toExprN k) rfl_true huv + else if k' < 0 then + let k' := -k' + mkApp6 (mkConst ``Grind.Nat.lo_eq_false_of_lo) u v (toExprN k) (toExprN k') rfl_true huv + else + assert! k' > 0 + mkApp6 (mkConst ``Grind.Nat.ro_eq_false_of_lo) u v (toExprN k) (toExprN k') rfl_true huv else - mkApp4 (mkConst ``Nat.of_ro_eq_false) u v (toExpr c.k.toNat) h + assert! k > 0 + assert! k' < 0 + let k' := -k' + mkApp6 (mkConst ``Grind.Nat.lo_eq_false_of_ro) u v (toExprN k) (toExprN k') rfl_true huv end Offset diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Types.lean b/src/Lean/Meta/Tactic/Grind/Arith/Types.lean index d4f6d5497b48..3e438bda6270 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Types.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Types.lean @@ -26,26 +26,34 @@ structure ProofInfo where /-- State of the constraint offset procedure. -/ structure State where - nodes : PArray Expr := {} - nodeMap : PHashMap ENodeKey NodeId := {} - cnstrs : PHashMap ENodeKey (Cnstr NodeId) := {} + /-- Mapping from `NodeId` to the `Expr` represented by the node. -/ + nodes : PArray Expr := {} + /-- Mapping from `Expr` to a node representing it. -/ + nodeMap : PHashMap ENodeKey NodeId := {} + /-- Mapping from `Expr` representing inequalites to constraints. -/ + cnstrs : PHashMap ENodeKey (Cnstr NodeId) := {} + /-- + Mapping from pairs `(u, v)` to a list of offset constraints on `u` and `v`. + We use this mapping to implement exhaustive constraint propagation. + -/ + cnstrsOf : PHashMap (NodeId × NodeId) (List (Cnstr NodeId × Expr)) := {} /-- For each node with id `u`, `sources[u]` contains pairs `(v, k)` s.t. there is a path from `v` to `u` with weight `k`. -/ - sources : PArray (AssocList NodeId Int) := {} + sources : PArray (AssocList NodeId Int) := {} /-- For each node with id `u`, `targets[u]` contains pairs `(v, k)` s.t. there is a path from `u` to `v` with weight `k`. -/ - targets : PArray (AssocList NodeId Int) := {} + targets : PArray (AssocList NodeId Int) := {} /-- Proof reconstruction information. For each node with id `u`, `proofs[u]` contains pairs `(v, { w, proof })` s.t. there is a path from `u` to `v`, and `w` is the penultimate node in the path, and `proof` is the justification for the last edge. -/ - proofs : PArray (AssocList NodeId ProofInfo) := {} + proofs : PArray (AssocList NodeId ProofInfo) := {} deriving Inhabited end Offset diff --git a/src/Lean/Meta/Tactic/Grind/Arith/Util.lean b/src/Lean/Meta/Tactic/Grind/Arith/Util.lean index 13e0ae6de46f..f3da57f4c750 100644 --- a/src/Lean/Meta/Tactic/Grind/Arith/Util.lean +++ b/src/Lean/Meta/Tactic/Grind/Arith/Util.lean @@ -47,26 +47,26 @@ def isNatOffset? (e : Expr) : Option (Expr × Nat) := Id.run do /-- An offset constraint. -/ structure Offset.Cnstr (α : Type) where - a : α - b : α + u : α + v : α k : Int := 0 le : Bool := true deriving Inhabited def Offset.Cnstr.neg : Cnstr α → Cnstr α - | { a, b, k, le } => { a := b, b := a, le, k := -k - 1 } + | { u, v, k, le } => { u := v, v := u, le, k := -k - 1 } example (c : Offset.Cnstr α) : c.neg.neg = c := by cases c; simp [Offset.Cnstr.neg]; omega def Offset.toMessageData [inst : ToMessageData α] (c : Offset.Cnstr α) : MessageData := match c.k, c.le with - | .ofNat 0, true => m!"{c.a} ≤ {c.b}" - | .ofNat 0, false => m!"{c.a} = {c.b}" - | .ofNat k, true => m!"{c.a} ≤ {c.b} + {k}" - | .ofNat k, false => m!"{c.a} = {c.b} + {k}" - | .negSucc k, true => m!"{c.a} + {k + 1} ≤ {c.b}" - | .negSucc k, false => m!"{c.a} + {k + 1} = {c.b}" + | .ofNat 0, true => m!"{c.u} ≤ {c.v}" + | .ofNat 0, false => m!"{c.u} = {c.v}" + | .ofNat k, true => m!"{c.u} ≤ {c.v} + {k}" + | .ofNat k, false => m!"{c.u} = {c.v} + {k}" + | .negSucc k, true => m!"{c.u} + {k + 1} ≤ {c.v}" + | .negSucc k, false => m!"{c.u} + {k + 1} = {c.v}" instance : ToMessageData (Offset.Cnstr Expr) where toMessageData c := Offset.toMessageData c @@ -78,12 +78,12 @@ def isNatOffsetCnstr? (e : Expr) : Option (Offset.Cnstr Expr) := | Eq α a b => if isNatType α then go a b false else none | _ => none where - go (a b : Expr) (le : Bool) := - if let some (a, k) := isNatOffset? a then - some { a, k := - k, b, le } - else if let some (b, k) := isNatOffset? b then - some { a, b, k := k, le } + go (u v : Expr) (le : Bool) := + if let some (u, k) := isNatOffset? u then + some { u, k := - k, v, le } + else if let some (v, k) := isNatOffset? v then + some { u, v, k := k, le } else - some { a, b, le } + some { u, v, le } end Lean.Meta.Grind.Arith diff --git a/tests/lean/run/grind_offset_cnstr.lean b/tests/lean/run/grind_offset_cnstr.lean index cb41c2c76242..0e4d960ca984 100644 --- a/tests/lean/run/grind_offset_cnstr.lean +++ b/tests/lean/run/grind_offset_cnstr.lean @@ -275,3 +275,80 @@ fun {a4} p a1 a2 a3 => #guard_msgs (info) in open Lean Grind in #print ex1 + +/-! Propagate `cnstr = False` tests -/ + +-- The following example is solved by `grind` using constraint propagation and 0 case-splits. +#guard_msgs (info) in +set_option trace.grind.split true in +example (p q r s : Prop) (a b : Nat) : a ≤ b → b + 2 ≤ c → (a + 1 ≤ c ↔ p) → (a + 2 ≤ c ↔ s) → (a ≤ c ↔ q) → (a ≤ c + 4 ↔ r) → p ∧ q ∧ r ∧ s := by + grind (splits := 0) + +-- The following example is solved by `grind` using constraint propagation and 0 case-splits. +#guard_msgs (info) in +set_option trace.grind.split true in +example (p q : Prop) (a b : Nat) : a ≤ b → b ≤ c → (a ≤ c ↔ p) → (a ≤ c + 1 ↔ q) → p ∧ q := by + grind (splits := 0) + +-- The following example is solved by `grind` using constraint propagation and 0 case-splits. +#guard_msgs (info) in +set_option trace.grind.split true in +example (p q : Prop) (a b : Nat) : a ≤ b → b ≤ c + 1 → (a ≤ c + 1 ↔ p) → (a ≤ c + 2 ↔ q) → p ∧ q := by + grind (splits := 0) + + +-- The following example is solved by `grind` using constraint propagation and 0 case-splits. +#guard_msgs (info) in +set_option trace.grind.split true in +example (p r s : Prop) (a b : Nat) : a ≤ b → b + 2 ≤ c → (c ≤ a ↔ p) → (c ≤ a + 1 ↔ s) → (c + 1 ≤ a ↔ r) → ¬p ∧ ¬r ∧ ¬s := by + grind (splits := 0) + +-- The following example is solved by `grind` using constraint propagation and 0 case-splits. +#guard_msgs (info) in +set_option trace.grind.split true in +example (p r : Prop) (a b : Nat) : a ≤ b → b ≤ c → (c + 1 ≤ a ↔ p) → (c + 2 ≤ a + 1 ↔ r) → ¬p ∧ ¬r := by + grind (splits := 0) + +-- The following example is solved by `grind` using constraint propagation and 0 case-splits. +#guard_msgs (info) in +set_option trace.grind.split true in +example (p r : Prop) (a b : Nat) : a ≤ b → b ≤ c + 3 → (c + 5 ≤ a ↔ p) → (c + 4 ≤ a ↔ r) → ¬p ∧ ¬r := by + grind (splits := 0) + +/-! Propagate `cnstr = False` tests, but with different internalization order -/ + +-- The following example is solved by `grind` using constraint propagation and 0 case-splits. +#guard_msgs (info) in +set_option trace.grind.split true in +example (p q r s : Prop) (a b : Nat) : (a + 1 ≤ c ↔ p) → (a + 2 ≤ c ↔ s) → (a ≤ c ↔ q) → (a ≤ c + 4 ↔ r) → a ≤ b → b + 2 ≤ c → p ∧ q ∧ r ∧ s := by + grind (splits := 0) + +-- The following example is solved by `grind` using constraint propagation and 0 case-splits. +#guard_msgs (info) in +set_option trace.grind.split true in +example (p q : Prop) (a b : Nat) : (a ≤ c ↔ p) → (a ≤ c + 1 ↔ q) → a ≤ b → b ≤ c → p ∧ q := by + grind (splits := 0) + +-- The following example is solved by `grind` using constraint propagation and 0 case-splits. +#guard_msgs (info) in +set_option trace.grind.split true in +example (p q : Prop) (a b : Nat) : (a ≤ c + 1 ↔ p) → (a ≤ c + 2 ↔ q) → a ≤ b → b ≤ c + 1 → p ∧ q := by + grind (splits := 0) + +-- The following example is solved by `grind` using constraint propagation and 0 case-splits. +#guard_msgs (info) in +set_option trace.grind.split true in +example (p r s : Prop) (a b : Nat) : (c ≤ a ↔ p) → (c ≤ a + 1 ↔ s) → (c + 1 ≤ a ↔ r) → a ≤ b → b + 2 ≤ c → ¬p ∧ ¬r ∧ ¬s := by + grind (splits := 0) + +-- The following example is solved by `grind` using constraint propagation and 0 case-splits. +#guard_msgs (info) in +set_option trace.grind.split true in +example (p r : Prop) (a b : Nat) : (c + 1 ≤ a ↔ p) → (c + 2 ≤ a + 1 ↔ r) → a ≤ b → b ≤ c → ¬p ∧ ¬r := by + grind (splits := 0) + +-- The following example is solved by `grind` using constraint propagation and 0 case-splits. +#guard_msgs (info) in +set_option trace.grind.split true in +example (p r : Prop) (a b : Nat) : (c + 5 ≤ a ↔ p) → (c + 4 ≤ a ↔ r) → a ≤ b → b ≤ c + 3 → ¬p ∧ ¬r := by + grind (splits := 0) diff --git a/tests/lean/run/grind_t1.lean b/tests/lean/run/grind_t1.lean index 1752266b8e31..d660f51b09d3 100644 --- a/tests/lean/run/grind_t1.lean +++ b/tests/lean/run/grind_t1.lean @@ -263,3 +263,9 @@ a✝ : p set_option trace.grind.split true in example (p q : Prop) : ¬(p ↔ q) → p → False := by grind -- should not split on (p ↔ q) + +example {a b : Nat} (h : a < b) : ¬ b < a := by + grind + +example {m n : Nat} : m < n ↔ m ≤ n ∧ ¬ n ≤ m := by + grind