Skip to content

Commit

Permalink
feat: exhaustive offset constraint propagation in the grind tactic (#…
Browse files Browse the repository at this point in the history
…6618)

This PR implements exhaustive offset constraint propagation in the
`grind` tactic. This enhancement minimizes the number of case splits
performed by `grind`. For instance, it can solve the following example
without performing any case splits:

```lean
example (p q r s : Prop) (a b : Nat) : (a + 1 ≤ c ↔ p) → (a + 2 ≤ c ↔ s) → (a ≤ c ↔ q) → (a ≤ c + 4 ↔ r) → a ≤ b → b + 2 ≤ c → p ∧ q ∧ r ∧ s := by
  grind (splits := 0)
```

TODO: support for equational offset constraints.
  • Loading branch information
leodemoura authored Jan 13, 2025
1 parent 40efbb9 commit 2421f7f
Show file tree
Hide file tree
Showing 10 changed files with 349 additions and 83 deletions.
35 changes: 34 additions & 1 deletion src/Init/Grind/Offset.lean
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ import Init.Core
import Init.Omega

namespace Lean.Grind
def isLt (x y : Nat) : Bool := x < y
abbrev isLt (x y : Nat) : Bool := x < y
abbrev isLE (x y : Nat) : Bool := x ≤ y

/-! Theorems for transitivity. -/
theorem Nat.le_ro (u w v k : Nat) : u ≤ w → w ≤ v + k → u ≤ v + k := by
omega
theorem Nat.le_lo (u w v k : Nat) : u ≤ w → w + k ≤ v → u + k ≤ v := by
Expand All @@ -31,6 +33,7 @@ theorem Nat.ro_lo_2 (u w v k₁ k₂ : Nat) : isLt k₁ k₂ = true → u ≤ w
theorem Nat.ro_ro (u w v k₁ k₂ : Nat) : u ≤ w + k₁ → w ≤ v + k₂ → u ≤ v + (k₁ + k₂) := by
omega

/-! Theorems for negating constraints. -/
theorem Nat.of_le_eq_false (u v : Nat) : ((u ≤ v) = False) → v + 1 ≤ u := by
simp; omega
theorem Nat.of_lo_eq_false_1 (u v : Nat) : ((u + 1 ≤ v) = False) → v ≤ u := by
Expand All @@ -40,11 +43,41 @@ theorem Nat.of_lo_eq_false (u v k : Nat) : ((u + k ≤ v) = False) → v ≤ u +
theorem Nat.of_ro_eq_false (u v k : Nat) : ((u ≤ v + k) = False) → v + (k+1) ≤ u := by
simp; omega

/-! Theorems for closing a goal. -/
theorem Nat.unsat_le_lo (u v k : Nat) : isLt 0 k = true → u ≤ v → v + k ≤ u → False := by
simp [isLt]; omega
theorem Nat.unsat_lo_lo (u v k₁ k₂ : Nat) : isLt 0 (k₁+k₂) = true → u + k₁ ≤ v → v + k₂ ≤ u → False := by
simp [isLt]; omega
theorem Nat.unsat_lo_ro (u v k₁ k₂ : Nat) : isLt k₂ k₁ = true → u + k₁ ≤ v → v ≤ u + k₂ → False := by
simp [isLt]; omega

/-! Theorems for propagating constraints to `True` -/
theorem Nat.lo_eq_true_of_lo (u v k₁ k₂ : Nat) : isLE k₂ k₁ = true → u + k₁ ≤ v → (u + k₂ ≤ v) = True :=
by simp [isLt]; omega
theorem Nat.le_eq_true_of_lo (u v k : Nat) : u + k ≤ v → (u ≤ v) = True :=
by simp; omega
theorem Nat.le_eq_true_of_le (u v : Nat) : u ≤ v → (u ≤ v) = True :=
by simp
theorem Nat.ro_eq_true_of_lo (u v k₁ k₂ : Nat) : u + k₁ ≤ v → (u ≤ v + k₂) = True :=
by simp; omega
theorem Nat.ro_eq_true_of_le (u v k : Nat) : u ≤ v → (u ≤ v + k) = True :=
by simp; omega
theorem Nat.ro_eq_true_of_ro (u v k₁ k₂ : Nat) : isLE k₁ k₂ = true → u ≤ v + k₁ → (u ≤ v + k₂) = True :=
by simp [isLE]; omega

/-!
Theorems for propagating constraints to `False`.
They are variants of the theorems for closing a goal.
-/
theorem Nat.lo_eq_false_of_le (u v k : Nat) : isLt 0 k = true → u ≤ v → (v + k ≤ u) = False := by
simp [isLt]; omega
theorem Nat.le_eq_false_of_lo (u v k : Nat) : isLt 0 k = true → u + k ≤ v → (v ≤ u) = False := by
simp [isLt]; omega
theorem Nat.lo_eq_false_of_lo (u v k₁ k₂ : Nat) : isLt 0 (k₁+k₂) = true → u + k₁ ≤ v → (v + k₂ ≤ u) = False := by
simp [isLt]; omega
theorem Nat.ro_eq_false_of_lo (u v k₁ k₂ : Nat) : isLt k₂ k₁ = true → u + k₁ ≤ v → (v ≤ u + k₂) = False := by
simp [isLt]; omega
theorem Nat.lo_eq_false_of_ro (u v k₁ k₂ : Nat) : isLt k₁ k₂ = true → u ≤ v + k₁ → (v + k₂ ≤ u) = False := by
simp [isLt]; omega

end Lean.Grind
1 change: 1 addition & 0 deletions src/Lean/Meta/Tactic/Grind.lean
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ builtin_initialize registerTraceClass `grind.offset
builtin_initialize registerTraceClass `grind.offset.dist
builtin_initialize registerTraceClass `grind.offset.internalize
builtin_initialize registerTraceClass `grind.offset.internalize.term (inherited := true)
builtin_initialize registerTraceClass `grind.offset.propagate

/-! Trace options for `grind` developers -/
builtin_initialize registerTraceClass `grind.debug
Expand Down
19 changes: 0 additions & 19 deletions src/Lean/Meta/Tactic/Grind/Arith/Internalize.lean
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,6 @@ import Lean.Meta.Tactic.Grind.Arith.Offset

namespace Lean.Meta.Grind.Arith

namespace Offset

def internalizeTerm (_e : Expr) (_a : Expr) (_k : Nat) : GoalM Unit := do
-- TODO
return ()

def internalizeCnstr (e : Expr) : GoalM Unit := do
let some c := isNatOffsetCnstr? e | return ()
let c := { c with
a := (← mkNode c.a)
b := (← mkNode c.b)
}
trace[grind.offset.internalize] "{e} ↦ {c}"
modify' fun s => { s with
cnstrs := s.cnstrs.insert { expr := e } c
}

end Offset

def internalize (e : Expr) : GoalM Unit := do
Offset.internalizeCnstr e

Expand Down
4 changes: 2 additions & 2 deletions src/Lean/Meta/Tactic/Grind/Arith/Main.lean
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ def isCnstr? (e : Expr) : GoalM (Option (Cnstr NodeId)) :=
return (← get).arith.offset.cnstrs.find? { expr := e }

def assertTrue (c : Cnstr NodeId) (p : Expr) : GoalM Unit := do
addEdge c.a c.b c.k (← mkOfEqTrue p)
addEdge c.u c.v c.k (← mkOfEqTrue p)

def assertFalse (c : Cnstr NodeId) (p : Expr) : GoalM Unit := do
let p := mkOfNegEqFalse (← get').nodes c p
let c := c.neg
addEdge c.a c.b c.k p
addEdge c.u c.v c.k p

end Offset

Expand Down
136 changes: 109 additions & 27 deletions src/Lean/Meta/Tactic/Grind/Arith/Offset.lean
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,20 @@ def mkNode (expr : Expr) : GoalM NodeId := do
}
return nodeId

private def getExpr (u : NodeId) : GoalM Expr := do
return (← get').nodes[u]!

private def getDist? (u v : NodeId) : GoalM (Option Int) := do
return (← get').targets[u]!.find? v

private def getProof? (u v : NodeId) : GoalM (Option ProofInfo) := do
return (← get').proofs[u]!.find? v

partial def extractProof (u v : NodeId) : GoalM Expr := do
/--
Returns a proof for `u + k ≤ v` (or `u ≤ v + k`) where `k` is the
shortest path between `u` and `v`.
-/
private partial def mkProofForPath (u v : NodeId) : GoalM Expr := do
go (← getProof? u v).get!
where
go (p : ProofInfo) : GoalM Expr := do
Expand All @@ -66,29 +73,21 @@ where
let p' := (← getProof? u p.w).get!
go (mkTrans (← get').nodes p' p v)

/--
Given a new edge edge `u --(kuv)--> v` justified by proof `huv` s.t.
it creates a negative cycle with the existing path `v --{kvu}-->* u`, i.e., `kuv + kvu < 0`,
this function closes the current goal by constructing a proof of `False`.
-/
private def setUnsat (u v : NodeId) (kuv : Int) (huv : Expr) (kvu : Int) : GoalM Unit := do
assert! kuv + kvu < 0
let hvu ← extractProof v u
let u := (← get').nodes[u]!
let v := (← get').nodes[v]!
if kuv == 0 then
assert! kvu < 0
closeGoal (mkApp6 (mkConst ``Grind.Nat.unsat_le_lo) u v (toExpr (-kvu).toNat) rfl_true huv hvu)
else if kvu == 0 then
assert! kuv < 0
closeGoal (mkApp6 (mkConst ``Grind.Nat.unsat_le_lo) v u (toExpr (-kuv).toNat) rfl_true hvu huv)
else if kuv < 0 then
if kvu > 0 then
closeGoal (mkApp7 (mkConst ``Grind.Nat.unsat_lo_ro) u v (toExpr (-kuv).toNat) (toExpr kvu.toNat) rfl_true huv hvu)
else
assert! kvu < 0
closeGoal (mkApp7 (mkConst ``Grind.Nat.unsat_lo_lo) u v (toExpr (-kuv).toNat) (toExpr (-kvu).toNat) rfl_true huv hvu)
else
assert! kuv > 0 && kvu < 0
closeGoal (mkApp7 (mkConst ``Grind.Nat.unsat_lo_ro) v u (toExpr (-kvu).toNat) (toExpr kuv.toNat) rfl_true hvu huv)
let hvu ← mkProofForPath v u
let u ← getExpr u
let v ← getExpr v
closeGoal (mkUnsatProof u v kuv huv kvu hvu)

/-- Sets the new shortest distance `k` between nodes `u` and `v`. -/
private def setDist (u v : NodeId) (k : Int) : GoalM Unit := do
trace[grind.offset.dist] "{({ a := u, b := v, k : Cnstr NodeId})}"
trace[grind.offset.dist] "{({ u, v, k : Cnstr NodeId})}"
modify' fun s => { s with
targets := s.targets.modify u fun es => es.insert v k
sources := s.sources.modify v fun es => es.insert u k
Expand All @@ -107,17 +106,80 @@ private def forEachSourceOf (u : NodeId) (f : NodeId → Int → GoalM Unit) : G
private def forEachTargetOf (u : NodeId) (f : NodeId → Int → GoalM Unit) : GoalM Unit := do
(← get').targets[u]!.forM f

/-- Returns `true` if `k` is smaller than the shortest distance between `u` and `v` -/
private def isShorter (u v : NodeId) (k : Int) : GoalM Bool := do
if let some k' ← getDist? u v then
return k < k'
else
return true

/--
Tries to assign `e` to `True`, which is represented by constraint `c` (from `u` to `v`), using the
path `u --(k)--> v`.
-/
private def propagateTrue (u v : NodeId) (k : Int) (c : Cnstr NodeId) (e : Expr) : GoalM Bool := do
if k ≤ c.k then
trace[grind.offset.propagate] "{{ u, v, k : Cnstr NodeId}} ==> {e} = True"
let kuv ← mkProofForPath u v
let u ← getExpr u
let v ← getExpr v
pushEqTrue e <| mkPropagateEqTrueProof u v k kuv c.k
return true
return false

example (x y : Nat) : x + 2 ≤ y → ¬ (y ≤ x + 1) := by omega

/--
Tries to assign `e` to `False`, which is represented by constraint `c` (from `v` to `u`), using the
path `u --(k)--> v`.
-/
private def propagateFalse (u v : NodeId) (k : Int) (c : Cnstr NodeId) (e : Expr) : GoalM Bool := do
if k + c.k < 0 then
trace[grind.offset.propagate] "{{ u, v, k : Cnstr NodeId}} ==> {e} = False"
let kuv ← mkProofForPath u v
let u ← getExpr u
let v ← getExpr v
pushEqFalse e <| mkPropagateEqFalseProof u v k kuv c.k
return false

/--
Auxiliary function for implementing `propagateAll`.
Traverses the constraints `c` (representing an expression `e`) s.t.
`c.u = u` and `c.v = v`, it removes `c` from the list of constraints
associated with `(u, v)` IF
- `e` is already assigned, or
- `f c e` returns true
-/
@[inline]
private def updateCnstrsOf (u v : NodeId) (f : Cnstr NodeId → Expr → GoalM Bool) : GoalM Unit := do
if let some cs := (← get').cnstrsOf.find? (u, v) then
let cs' ← cs.filterM fun (c, e) => do
if (← isEqTrue e <||> isEqFalse e) then
return false -- constraint was already assigned
else
return !(← f c e)
modify' fun s => { s with cnstrsOf := s.cnstrsOf.insert (u, v) cs' }

/-- Performs constraint propagation. -/
private def propagateAll (u v : NodeId) (k : Int) : GoalM Unit := do
updateCnstrsOf u v fun c e => return !(← propagateTrue u v k c e)
updateCnstrsOf v u fun c e => return !(← propagateFalse u v k c e)

/--
If `isShorter u v k`, updates the shortest distance between `u` and `v`.
`w` is the penultimate node in the path from `u` to `v`.
-/
private def updateIfShorter (u v : NodeId) (k : Int) (w : NodeId) : GoalM Unit := do
if (← isShorter u v k) then
setDist u v k
setProof u v (← getProof? w v).get!
propagateAll u v k

/--
Adds an edge `u --(k) --> v` justified by the proof term `p`, and then
if no negative cycle was created, updates the shortest distance of affected
node pairs.
-/
def addEdge (u : NodeId) (v : NodeId) (k : Int) (p : Expr) : GoalM Unit := do
if (← isInconsistent) then return ()
if let some k' ← getDist? v u then
Expand All @@ -127,6 +189,7 @@ def addEdge (u : NodeId) (v : NodeId) (k : Int) (p : Expr) : GoalM Unit := do
if (← isShorter u v k) then
setDist u v k
setProof u v { w := u, k, proof := p }
propagateAll u v k
update
where
update : GoalM Unit := do
Expand All @@ -140,30 +203,49 @@ where
/- Check whether new path: `i -(k₁)-> u -(k)-> v -(k₂) -> j` is shorter -/
updateIfShorter i j (k₁+k+k₂) v

def internalizeCnstr (e : Expr) : GoalM Unit := do
let some c := isNatOffsetCnstr? e | return ()
let u ← mkNode c.u
let v ← mkNode c.v
let c := { c with u, v }
if let some k ← getDist? u v then
if (← propagateTrue u v k c e) then
return ()
if let some k ← getDist? v u then
if (← propagateFalse v u k c e) then
return ()
trace[grind.offset.internalize] "{e} ↦ {c}"
modify' fun s => { s with
cnstrs := s.cnstrs.insert { expr := e } c
cnstrsOf :=
let cs := if let some cs := s.cnstrsOf.find? (u, v) then (c, e) :: cs else [(c, e)]
s.cnstrsOf.insert (u, v) cs
}

def traceDists : GoalM Unit := do
let s ← get'
for u in [:s.targets.size], es in s.targets.toArray do
for (v, k) in es do
trace[grind.offset.dist] "#{u} -({k})-> #{v}"

def Cnstr.toExpr (c : Cnstr NodeId) : GoalM Expr := do
let a := (← get').nodes[c.a]!
let b := (← get').nodes[c.b]!
let u := (← get').nodes[c.u]!
let v := (← get').nodes[c.v]!
let mk := if c.le then mkNatLE else mkNatEq
if c.k == 0 then
return mk a b
return mk u v
else if c.k < 0 then
return mk (mkNatAdd a (Lean.toExpr ((-c.k).toNat))) b
return mk (mkNatAdd u (Lean.toExpr ((-c.k).toNat))) v
else
return mk a (mkNatAdd b (Lean.toExpr c.k.toNat))
return mk u (mkNatAdd v (Lean.toExpr c.k.toNat))

def checkInvariants : GoalM Unit := do
let s ← get'
for u in [:s.targets.size], es in s.targets.toArray do
for (v, k) in es do
let c : Cnstr NodeId := { a := u, b := v, k }
let c : Cnstr NodeId := { u, v, k }
trace[grind.debug.offset] "{c}"
let p ← extractProof u v
let p ← mkProofForPath u v
trace[grind.debug.offset.proof] "{p} : {← inferType p}"
check p
unless (← withDefault <| isDefEq (← inferType p) (← Cnstr.toExpr c)) do
Expand Down
Loading

0 comments on commit 2421f7f

Please sign in to comment.