Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: exhaustive offset constraint propagation in the grind tactic #6618

Merged
merged 10 commits into from
Jan 13, 2025
35 changes: 34 additions & 1 deletion src/Init/Grind/Offset.lean
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ import Init.Core
import Init.Omega

namespace Lean.Grind
def isLt (x y : Nat) : Bool := x < y
abbrev isLt (x y : Nat) : Bool := x < y
abbrev isLE (x y : Nat) : Bool := x ≤ y

/-! Theorems for transitivity. -/
theorem Nat.le_ro (u w v k : Nat) : u ≤ w → w ≤ v + k → u ≤ v + k := by
omega
theorem Nat.le_lo (u w v k : Nat) : u ≤ w → w + k ≤ v → u + k ≤ v := by
Expand All @@ -31,6 +33,7 @@ theorem Nat.ro_lo_2 (u w v k₁ k₂ : Nat) : isLt k₁ k₂ = true → u ≤ w
theorem Nat.ro_ro (u w v k₁ k₂ : Nat) : u ≤ w + k₁ → w ≤ v + k₂ → u ≤ v + (k₁ + k₂) := by
omega

/-! Theorems for negating constraints. -/
theorem Nat.of_le_eq_false (u v : Nat) : ((u ≤ v) = False) → v + 1 ≤ u := by
simp; omega
theorem Nat.of_lo_eq_false_1 (u v : Nat) : ((u + 1 ≤ v) = False) → v ≤ u := by
Expand All @@ -40,11 +43,41 @@ theorem Nat.of_lo_eq_false (u v k : Nat) : ((u + k ≤ v) = False) → v ≤ u +
theorem Nat.of_ro_eq_false (u v k : Nat) : ((u ≤ v + k) = False) → v + (k+1) ≤ u := by
simp; omega

/-! Theorems for closing a goal. -/
theorem Nat.unsat_le_lo (u v k : Nat) : isLt 0 k = true → u ≤ v → v + k ≤ u → False := by
simp [isLt]; omega
theorem Nat.unsat_lo_lo (u v k₁ k₂ : Nat) : isLt 0 (k₁+k₂) = true → u + k₁ ≤ v → v + k₂ ≤ u → False := by
simp [isLt]; omega
theorem Nat.unsat_lo_ro (u v k₁ k₂ : Nat) : isLt k₂ k₁ = true → u + k₁ ≤ v → v ≤ u + k₂ → False := by
simp [isLt]; omega

/-! Theorems for propagating constraints to `True` -/
theorem Nat.lo_eq_true_of_lo (u v k₁ k₂ : Nat) : isLE k₂ k₁ = true → u + k₁ ≤ v → (u + k₂ ≤ v) = True :=
by simp [isLt]; omega
theorem Nat.le_eq_true_of_lo (u v k : Nat) : u + k ≤ v → (u ≤ v) = True :=
by simp; omega
theorem Nat.le_eq_true_of_le (u v : Nat) : u ≤ v → (u ≤ v) = True :=
by simp
theorem Nat.ro_eq_true_of_lo (u v k₁ k₂ : Nat) : u + k₁ ≤ v → (u ≤ v + k₂) = True :=
by simp; omega
theorem Nat.ro_eq_true_of_le (u v k : Nat) : u ≤ v → (u ≤ v + k) = True :=
by simp; omega
theorem Nat.ro_eq_true_of_ro (u v k₁ k₂ : Nat) : isLE k₁ k₂ = true → u ≤ v + k₁ → (u ≤ v + k₂) = True :=
by simp [isLE]; omega

/-!
Theorems for propagating constraints to `False`.
They are variants of the theorems for closing a goal.
-/
theorem Nat.lo_eq_false_of_le (u v k : Nat) : isLt 0 k = true → u ≤ v → (v + k ≤ u) = False := by
simp [isLt]; omega
theorem Nat.le_eq_false_of_lo (u v k : Nat) : isLt 0 k = true → u + k ≤ v → (v ≤ u) = False := by
simp [isLt]; omega
theorem Nat.lo_eq_false_of_lo (u v k₁ k₂ : Nat) : isLt 0 (k₁+k₂) = true → u + k₁ ≤ v → (v + k₂ ≤ u) = False := by
simp [isLt]; omega
theorem Nat.ro_eq_false_of_lo (u v k₁ k₂ : Nat) : isLt k₂ k₁ = true → u + k₁ ≤ v → (v ≤ u + k₂) = False := by
simp [isLt]; omega
theorem Nat.lo_eq_false_of_ro (u v k₁ k₂ : Nat) : isLt k₁ k₂ = true → u ≤ v + k₁ → (v + k₂ ≤ u) = False := by
simp [isLt]; omega

end Lean.Grind
1 change: 1 addition & 0 deletions src/Lean/Meta/Tactic/Grind.lean
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ builtin_initialize registerTraceClass `grind.offset
builtin_initialize registerTraceClass `grind.offset.dist
builtin_initialize registerTraceClass `grind.offset.internalize
builtin_initialize registerTraceClass `grind.offset.internalize.term (inherited := true)
builtin_initialize registerTraceClass `grind.offset.propagate

/-! Trace options for `grind` developers -/
builtin_initialize registerTraceClass `grind.debug
Expand Down
19 changes: 0 additions & 19 deletions src/Lean/Meta/Tactic/Grind/Arith/Internalize.lean
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,6 @@ import Lean.Meta.Tactic.Grind.Arith.Offset

namespace Lean.Meta.Grind.Arith

namespace Offset

def internalizeTerm (_e : Expr) (_a : Expr) (_k : Nat) : GoalM Unit := do
-- TODO
return ()

def internalizeCnstr (e : Expr) : GoalM Unit := do
let some c := isNatOffsetCnstr? e | return ()
let c := { c with
a := (← mkNode c.a)
b := (← mkNode c.b)
}
trace[grind.offset.internalize] "{e} ↦ {c}"
modify' fun s => { s with
cnstrs := s.cnstrs.insert { expr := e } c
}

end Offset

def internalize (e : Expr) : GoalM Unit := do
Offset.internalizeCnstr e

Expand Down
4 changes: 2 additions & 2 deletions src/Lean/Meta/Tactic/Grind/Arith/Main.lean
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ def isCnstr? (e : Expr) : GoalM (Option (Cnstr NodeId)) :=
return (← get).arith.offset.cnstrs.find? { expr := e }

def assertTrue (c : Cnstr NodeId) (p : Expr) : GoalM Unit := do
addEdge c.a c.b c.k (← mkOfEqTrue p)
addEdge c.u c.v c.k (← mkOfEqTrue p)

def assertFalse (c : Cnstr NodeId) (p : Expr) : GoalM Unit := do
let p := mkOfNegEqFalse (← get').nodes c p
let c := c.neg
addEdge c.a c.b c.k p
addEdge c.u c.v c.k p

end Offset

Expand Down
136 changes: 109 additions & 27 deletions src/Lean/Meta/Tactic/Grind/Arith/Offset.lean
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,20 @@ def mkNode (expr : Expr) : GoalM NodeId := do
}
return nodeId

private def getExpr (u : NodeId) : GoalM Expr := do
return (← get').nodes[u]!

private def getDist? (u v : NodeId) : GoalM (Option Int) := do
return (← get').targets[u]!.find? v

private def getProof? (u v : NodeId) : GoalM (Option ProofInfo) := do
return (← get').proofs[u]!.find? v

partial def extractProof (u v : NodeId) : GoalM Expr := do
/--
Returns a proof for `u + k ≤ v` (or `u ≤ v + k`) where `k` is the
shortest path between `u` and `v`.
-/
private partial def mkProofForPath (u v : NodeId) : GoalM Expr := do
go (← getProof? u v).get!
where
go (p : ProofInfo) : GoalM Expr := do
Expand All @@ -66,29 +73,21 @@ where
let p' := (← getProof? u p.w).get!
go (mkTrans (← get').nodes p' p v)

/--
Given a new edge edge `u --(kuv)--> v` justified by proof `huv` s.t.
it creates a negative cycle with the existing path `v --{kvu}-->* u`, i.e., `kuv + kvu < 0`,
this function closes the current goal by constructing a proof of `False`.
-/
private def setUnsat (u v : NodeId) (kuv : Int) (huv : Expr) (kvu : Int) : GoalM Unit := do
assert! kuv + kvu < 0
let hvu ← extractProof v u
let u := (← get').nodes[u]!
let v := (← get').nodes[v]!
if kuv == 0 then
assert! kvu < 0
closeGoal (mkApp6 (mkConst ``Grind.Nat.unsat_le_lo) u v (toExpr (-kvu).toNat) rfl_true huv hvu)
else if kvu == 0 then
assert! kuv < 0
closeGoal (mkApp6 (mkConst ``Grind.Nat.unsat_le_lo) v u (toExpr (-kuv).toNat) rfl_true hvu huv)
else if kuv < 0 then
if kvu > 0 then
closeGoal (mkApp7 (mkConst ``Grind.Nat.unsat_lo_ro) u v (toExpr (-kuv).toNat) (toExpr kvu.toNat) rfl_true huv hvu)
else
assert! kvu < 0
closeGoal (mkApp7 (mkConst ``Grind.Nat.unsat_lo_lo) u v (toExpr (-kuv).toNat) (toExpr (-kvu).toNat) rfl_true huv hvu)
else
assert! kuv > 0 && kvu < 0
closeGoal (mkApp7 (mkConst ``Grind.Nat.unsat_lo_ro) v u (toExpr (-kvu).toNat) (toExpr kuv.toNat) rfl_true hvu huv)
let hvu ← mkProofForPath v u
let u ← getExpr u
let v ← getExpr v
closeGoal (mkUnsatProof u v kuv huv kvu hvu)

/-- Sets the new shortest distance `k` between nodes `u` and `v`. -/
private def setDist (u v : NodeId) (k : Int) : GoalM Unit := do
trace[grind.offset.dist] "{({ a := u, b := v, k : Cnstr NodeId})}"
trace[grind.offset.dist] "{({ u, v, k : Cnstr NodeId})}"
modify' fun s => { s with
targets := s.targets.modify u fun es => es.insert v k
sources := s.sources.modify v fun es => es.insert u k
Expand All @@ -107,17 +106,80 @@ private def forEachSourceOf (u : NodeId) (f : NodeId → Int → GoalM Unit) : G
private def forEachTargetOf (u : NodeId) (f : NodeId → Int → GoalM Unit) : GoalM Unit := do
(← get').targets[u]!.forM f

/-- Returns `true` if `k` is smaller than the shortest distance between `u` and `v` -/
private def isShorter (u v : NodeId) (k : Int) : GoalM Bool := do
if let some k' ← getDist? u v then
return k < k'
else
return true

/--
Tries to assign `e` to `True`, which is represented by constraint `c` (from `u` to `v`), using the
path `u --(k)--> v`.
-/
private def propagateTrue (u v : NodeId) (k : Int) (c : Cnstr NodeId) (e : Expr) : GoalM Bool := do
if k ≤ c.k then
trace[grind.offset.propagate] "{{ u, v, k : Cnstr NodeId}} ==> {e} = True"
let kuv ← mkProofForPath u v
let u ← getExpr u
let v ← getExpr v
pushEqTrue e <| mkPropagateEqTrueProof u v k kuv c.k
return true
return false

example (x y : Nat) : x + 2 ≤ y → ¬ (y ≤ x + 1) := by omega

/--
Tries to assign `e` to `False`, which is represented by constraint `c` (from `v` to `u`), using the
path `u --(k)--> v`.
-/
private def propagateFalse (u v : NodeId) (k : Int) (c : Cnstr NodeId) (e : Expr) : GoalM Bool := do
if k + c.k < 0 then
trace[grind.offset.propagate] "{{ u, v, k : Cnstr NodeId}} ==> {e} = False"
let kuv ← mkProofForPath u v
let u ← getExpr u
let v ← getExpr v
pushEqFalse e <| mkPropagateEqFalseProof u v k kuv c.k
return false

/--
Auxiliary function for implementing `propagateAll`.
Traverses the constraints `c` (representing an expression `e`) s.t.
`c.u = u` and `c.v = v`, it removes `c` from the list of constraints
associated with `(u, v)` IF
- `e` is already assigned, or
- `f c e` returns true
-/
@[inline]
private def updateCnstrsOf (u v : NodeId) (f : Cnstr NodeId → Expr → GoalM Bool) : GoalM Unit := do
if let some cs := (← get').cnstrsOf.find? (u, v) then
let cs' ← cs.filterM fun (c, e) => do
if (← isEqTrue e <||> isEqFalse e) then
return false -- constraint was already assigned
else
return !(← f c e)
modify' fun s => { s with cnstrsOf := s.cnstrsOf.insert (u, v) cs' }

/-- Performs constraint propagation. -/
private def propagateAll (u v : NodeId) (k : Int) : GoalM Unit := do
updateCnstrsOf u v fun c e => return !(← propagateTrue u v k c e)
updateCnstrsOf v u fun c e => return !(← propagateFalse u v k c e)

/--
If `isShorter u v k`, updates the shortest distance between `u` and `v`.
`w` is the penultimate node in the path from `u` to `v`.
-/
private def updateIfShorter (u v : NodeId) (k : Int) (w : NodeId) : GoalM Unit := do
if (← isShorter u v k) then
setDist u v k
setProof u v (← getProof? w v).get!
propagateAll u v k

/--
Adds an edge `u --(k) --> v` justified by the proof term `p`, and then
if no negative cycle was created, updates the shortest distance of affected
node pairs.
-/
def addEdge (u : NodeId) (v : NodeId) (k : Int) (p : Expr) : GoalM Unit := do
if (← isInconsistent) then return ()
if let some k' ← getDist? v u then
Expand All @@ -127,6 +189,7 @@ def addEdge (u : NodeId) (v : NodeId) (k : Int) (p : Expr) : GoalM Unit := do
if (← isShorter u v k) then
setDist u v k
setProof u v { w := u, k, proof := p }
propagateAll u v k
update
where
update : GoalM Unit := do
Expand All @@ -140,30 +203,49 @@ where
/- Check whether new path: `i -(k₁)-> u -(k)-> v -(k₂) -> j` is shorter -/
updateIfShorter i j (k₁+k+k₂) v

def internalizeCnstr (e : Expr) : GoalM Unit := do
let some c := isNatOffsetCnstr? e | return ()
let u ← mkNode c.u
let v ← mkNode c.v
let c := { c with u, v }
if let some k ← getDist? u v then
if (← propagateTrue u v k c e) then
return ()
if let some k ← getDist? v u then
if (← propagateFalse v u k c e) then
return ()
trace[grind.offset.internalize] "{e} ↦ {c}"
modify' fun s => { s with
cnstrs := s.cnstrs.insert { expr := e } c
cnstrsOf :=
let cs := if let some cs := s.cnstrsOf.find? (u, v) then (c, e) :: cs else [(c, e)]
s.cnstrsOf.insert (u, v) cs
}

def traceDists : GoalM Unit := do
let s ← get'
for u in [:s.targets.size], es in s.targets.toArray do
for (v, k) in es do
trace[grind.offset.dist] "#{u} -({k})-> #{v}"

def Cnstr.toExpr (c : Cnstr NodeId) : GoalM Expr := do
let a := (← get').nodes[c.a]!
let b := (← get').nodes[c.b]!
let u := (← get').nodes[c.u]!
let v := (← get').nodes[c.v]!
let mk := if c.le then mkNatLE else mkNatEq
if c.k == 0 then
return mk a b
return mk u v
else if c.k < 0 then
return mk (mkNatAdd a (Lean.toExpr ((-c.k).toNat))) b
return mk (mkNatAdd u (Lean.toExpr ((-c.k).toNat))) v
else
return mk a (mkNatAdd b (Lean.toExpr c.k.toNat))
return mk u (mkNatAdd v (Lean.toExpr c.k.toNat))

def checkInvariants : GoalM Unit := do
let s ← get'
for u in [:s.targets.size], es in s.targets.toArray do
for (v, k) in es do
let c : Cnstr NodeId := { a := u, b := v, k }
let c : Cnstr NodeId := { u, v, k }
trace[grind.debug.offset] "{c}"
let p ← extractProof u v
let p ← mkProofForPath u v
trace[grind.debug.offset.proof] "{p} : {← inferType p}"
check p
unless (← withDefault <| isDefEq (← inferType p) (← Cnstr.toExpr c)) do
Expand Down
Loading
Loading