Skip to content

Commit

Permalink
feat: beta reduction in grind (#6700)
Browse files Browse the repository at this point in the history
This PR adds support for beta reduction in the `grind` tactic. `grind`
can now solve goals such as
```lean
example (f : Nat → Nat) : f = (fun x : Nat => x + 5) → f 2 > 5 := by
  grind
```
  • Loading branch information
leodemoura authored Jan 19, 2025
1 parent 645bdea commit a062eea
Show file tree
Hide file tree
Showing 12 changed files with 221 additions and 20 deletions.
2 changes: 2 additions & 0 deletions src/Lean/Meta/Tactic/Grind.lean
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ builtin_initialize registerTraceClass `grind.offset.propagate
builtin_initialize registerTraceClass `grind.offset.eq
builtin_initialize registerTraceClass `grind.offset.eq.to (inherited := true)
builtin_initialize registerTraceClass `grind.offset.eq.from (inherited := true)
builtin_initialize registerTraceClass `grind.beta

/-! Trace options for `grind` developers -/
builtin_initialize registerTraceClass `grind.debug
Expand All @@ -68,5 +69,6 @@ builtin_initialize registerTraceClass `grind.debug.canon
builtin_initialize registerTraceClass `grind.debug.offset
builtin_initialize registerTraceClass `grind.debug.offset.proof
builtin_initialize registerTraceClass `grind.debug.ematch.pattern
builtin_initialize registerTraceClass `grind.debug.beta

end Lean
77 changes: 77 additions & 0 deletions src/Lean/Meta/Tactic/Grind/Beta.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/-
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Meta.Tactic.Grind.Types

namespace Lean.Meta.Grind

/-- Returns all lambda expressions in the equivalence class with root `root`. -/
def getEqcLambdas (root : ENode) : GoalM (Array Expr) := do
unless root.hasLambdas do return #[]
foldEqc root.self (init := #[]) fun n lams =>
if n.self.isLambda then return lams.push n.self else return lams

/--
Returns the root of the functions in the equivalence class containing `e`.
That is, if `f a` is in `root`s equivalence class, results contains the root of `f`.
-/
def getFnRoots (e : Expr) : GoalM (Array Expr) := do
foldEqc e (init := #[]) fun n fns => do
let fn := n.self.getAppFn
let fnRoot := (← getRoot? fn).getD fn
if Option.isNone <| fns.find? (isSameExpr · fnRoot) then
return fns.push fnRoot
else
return fns

/--
For each `lam` in `lams` s.t. `lam` and `f` are in the same equivalence class,
propagate `f args = lam args`.
-/
def propagateBetaEqs (lams : Array Expr) (f : Expr) (args : Array Expr) : GoalM Unit := do
if args.isEmpty then return ()
for lam in lams do
let rhs := lam.beta args
unless rhs.isLambda do
let mut gen := Nat.max (← getGeneration lam) (← getGeneration f)
let lhs := mkAppN f args
if (← hasSameType f lam) then
let mut h ← mkEqProof f lam
for arg in args do
gen := Nat.max gen (← getGeneration arg)
h ← mkCongrFun h arg
let eq ← mkEq lhs rhs
trace[grind.beta] "{eq}, using {lam}"
addNewFact h eq (gen+1)

private def isPropagateBetaTarget (e : Expr) : GoalM Bool := do
let .app f _ := e | return false
go f
where
go (f : Expr) : GoalM Bool := do
if let some root ← getRootENode? f then
return root.hasLambdas
let .app f _ := f | return false
go f

/--
Applies beta-reduction for lambdas in `f`s equivalence class.
We use this function while internalizing new applications.
-/
def propagateBetaForNewApp (e : Expr) : GoalM Unit := do
unless (← isPropagateBetaTarget e) do return ()
let mut e := e
let mut args := #[]
repeat
unless args.isEmpty do
if let some root ← getRootENode? e then
if root.hasLambdas then
propagateBetaEqs (← getEqcLambdas root) e args.reverse
let .app f arg := e | return ()
e := f
args := args.push arg

end Lean.Meta.Grind
35 changes: 34 additions & 1 deletion src/Lean/Meta/Tactic/Grind/Core.lean
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import Lean.Meta.Tactic.Grind.Inv
import Lean.Meta.Tactic.Grind.PP
import Lean.Meta.Tactic.Grind.Ctor
import Lean.Meta.Tactic.Grind.Util
import Lean.Meta.Tactic.Grind.Beta
import Lean.Meta.Tactic.Grind.Internalize

namespace Lean.Meta.Grind
Expand Down Expand Up @@ -40,7 +41,7 @@ Remove `root` parents from the congruence table.
This is an auxiliary function performed while merging equivalence classes.
-/
private def removeParents (root : Expr) : GoalM ParentSet := do
let parents ← getParentsAndReset root
let parents ← getParents root
for parent in parents do
-- Recall that we may have `Expr.forallE` in `parents` because of `ForallProp.lean`
if (← pure parent.isApp <&&> isCongrRoot parent) then
Expand Down Expand Up @@ -107,6 +108,31 @@ private def propagateOffsetEq (rhsRoot lhsRoot : ENode) : GoalM Unit := do
if let some rhsOffset := rhsRoot.offset? then
Arith.processNewOffsetEqLit rhsOffset lhsRoot.self

/--
Tries to apply beta-reductiong using the parent applications of the functions in `fns` with
the lambda expressions in `lams`.
-/
def propagateBeta (lams : Array Expr) (fns : Array Expr) : GoalM Unit := do
if lams.isEmpty then return ()
let lamRoot ← getRoot lams.back!
trace[grind.debug.beta] "fns: {fns}, lams: {lams}"
for fn in fns do
trace[grind.debug.beta] "fn: {fn}, parents: {(← getParents fn).toArray}"
for parent in (← getParents fn) do
let mut args := #[]
let mut curr := parent
trace[grind.debug.beta] "parent: {parent}"
repeat
trace[grind.debug.beta] "curr: {curr}"
if (← isEqv curr lamRoot) then
propagateBetaEqs lams curr args.reverse
let .app f arg := curr
| break
-- Remark: recall that we do not eagerly internalize partial applications.
internalize curr (← getGeneration parent)
args := args.push arg
curr := f

private partial def addEqStep (lhs rhs proof : Expr) (isHEq : Bool) : GoalM Unit := do
let lhsNode ← getENode lhs
let rhsNode ← getENode rhs
Expand Down Expand Up @@ -158,6 +184,10 @@ where
proof? := proof
flipped
}
let lams₁ ← getEqcLambdas lhsRoot
let lams₂ ← getEqcLambdas rhsRoot
let fns₁ ← if lams₁.isEmpty then pure #[] else getFnRoots rhsRoot.self
let fns₂ ← if lams₂.isEmpty then pure #[] else getFnRoots lhsRoot.self
let parents ← removeParents lhsRoot.self
updateRoots lhs rhsNode.root
trace_goal[grind.debug] "{← ppENodeRef lhs} new root {← ppENodeRef rhsNode.root}, {← ppENodeRef (← getRoot lhs)}"
Expand All @@ -172,6 +202,9 @@ where
hasLambdas := rhsRoot.hasLambdas || lhsRoot.hasLambdas
heqProofs := isHEq || rhsRoot.heqProofs || lhsRoot.heqProofs
}
propagateBeta lams₁ fns₁
propagateBeta lams₂ fns₂
resetParentsOf lhsRoot.self
copyParentsTo parents rhsNode.root
unless (← isInconsistent) do
updateMT rhsRoot.self
Expand Down
5 changes: 4 additions & 1 deletion src/Lean/Meta/Tactic/Grind/Internalize.lean
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import Lean.Meta.Match.MatchEqsExt
import Lean.Meta.Tactic.Grind.Types
import Lean.Meta.Tactic.Grind.Util
import Lean.Meta.Tactic.Grind.Canon
import Lean.Meta.Tactic.Grind.Beta
import Lean.Meta.Tactic.Grind.Arith.Internalize

namespace Lean.Meta.Grind
Expand Down Expand Up @@ -194,7 +195,7 @@ partial def internalize (e : Expr) (generation : Nat) (parent? : Option Expr :=
activateTheoremPatterns fName generation
else
internalize f generation e
registerParent e f
registerParent e f
for h : i in [: args.size] do
let arg := args[i]
internalize arg generation e
Expand All @@ -204,6 +205,8 @@ partial def internalize (e : Expr) (generation : Nat) (parent? : Option Expr :=
updateAppMap e
Arith.internalize e parent?
propagateUp e
propagateBetaForNewApp e

end

end Lean.Meta.Grind
34 changes: 25 additions & 9 deletions src/Lean/Meta/Tactic/Grind/Types.lean
Original file line number Diff line number Diff line change
Expand Up @@ -501,8 +501,9 @@ def getENode (e : Expr) : GoalM ENode := do
(← get).getENode e

/-- Returns the generation of the given term. Is assumes it has been internalized -/
def getGeneration (e : Expr) : GoalM Nat :=
return (← getENode e).generation
def getGeneration (e : Expr) : GoalM Nat := do
let some n ← getENode? e | return 0
return n.generation

/-- Returns `true` if `e` is in the equivalence class of `True`. -/
def isEqTrue (e : Expr) : GoalM Bool := do
Expand All @@ -519,8 +520,8 @@ def isEqv (a b : Expr) : GoalM Bool := do
if isSameExpr a b then
return true
else
let na ← getENode a
let nb ← getENode b
let some na ← getENode? a | return false
let some nb ← getENode? b | return false
return isSameExpr na.root nb.root

/-- Returns `true` if the root of its equivalence class. -/
Expand Down Expand Up @@ -549,6 +550,11 @@ def getRoot (e : Expr) : GoalM Expr := do
def getRootENode (e : Expr) : GoalM ENode := do
getENode (← getRoot e)

/-- Returns the root enode in the equivalence class of `e` if it is in an equivalence class. -/
def getRootENode? (e : Expr) : GoalM (Option ENode) := do
let some n ← getENode? e | return none
getENode? n.root

/--
Returns the next element in the equivalence class of `e`
if `e` has been internalized in the given goal.
Expand Down Expand Up @@ -614,7 +620,7 @@ Records that `parent` is a parent of `child`. This function actually stores the
information in the root (aka canonical representative) of `child`.
-/
def registerParent (parent : Expr) (child : Expr) : GoalM Unit := do
let some childRoot ← getRoot? child | return ()
let childRoot := (← getRoot? child).getD child
let parents := if let some parents := (← get).parents.find? { expr := childRoot } then parents else {}
modify fun s => { s with parents := s.parents.insert { expr := childRoot } (parents.insert parent) }

Expand All @@ -628,12 +634,10 @@ def getParents (e : Expr) : GoalM ParentSet := do
return parents

/--
Similar to `getParents`, but also removes the entry `e ↦ parents` from the parent map.
Removes the entry `e ↦ parents` from the parent map.
-/
def getParentsAndReset (e : Expr) : GoalM ParentSet := do
let parents ← getParents e
def resetParentsOf (e : Expr) : GoalM Unit := do
modify fun s => { s with parents := s.parents.erase { expr := e } }
return parents

/--
Copy `parents` to the parents of `root`.
Expand Down Expand Up @@ -800,6 +804,18 @@ def getENodes : GoalM (Array ENode) := do
if isSameExpr n.next e then return ()
curr := n.next

/-- Folds using `f` and `init` over the equivalence class containing `e` -/
@[inline] def foldEqc (e : Expr) (init : α) (f : ENode → α → GoalM α) : GoalM α := do
let mut curr := e
let mut r := init
repeat
let n ← getENode curr
r ← f n r
if isSameExpr n.next e then return r
curr := n.next
unreachable!
return r

def forEachENode (f : ENode → GoalM Unit) : GoalM Unit := do
let nodes ← getENodes
for n in nodes do
Expand Down
72 changes: 72 additions & 0 deletions tests/lean/run/grind_beta.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
def f (x : Nat) : Nat → Nat → Nat :=
match x with
| 0 => fun _ _ => 0
| _+1 => fun a b => a + b

example : f 0 b c = 0 := by
grind [f]

example : f (a+1) b c = b + c := by
grind [f]

example : f x b c ≠ b + c → x = a + 1 → False := by
grind [f]

example : x = a + 1 → f x b c ≠ b + c → False := by
grind [f]

example : x = a + 1 → f x b c ≠ b + c → False := by
grind [f]

example : f x b c > 0 → x = 0 → False := by
grind [f]

example : f x b c > 0 → x ≠ 0 := by
grind [f]

example (f : Nat → Nat → Nat) : f 2 35 → f = (fun x y : Nat => x + y) → False := by
grind

opaque bla : Nat → Nat → Nat → Nat

/--
info: [grind.beta] f 2 3 = bla 2 3 2, using fun x y => bla x y x
[grind.beta] f 2 3 = 2 + 3, using fun x y => x + y
-/
#guard_msgs (info) in
set_option trace.grind.beta true in
example (g h f : Nat → Nat → Nat) :
f 2 35
g = (fun x y : Nat => x + y) →
h = (fun x y => bla x y x) →
g = h →
f = h →
False := by
grind

example (g h f : Nat → Nat → Nat) :
f 2 35
h = (fun x y => bla x y x) →
g = (fun x y : Nat => x + y) →
g = h →
h = f →
False := by
grind


example (f : Nat → Nat → Nat) : f = (fun x y : Nat => x + y) → f 2 3 = 5 := by
grind

example (f g h : Nat → Nat → Nat) :
h = (fun x y => bla x y x) →
g = (fun x y : Nat => x + y) →
g = h →
h = f →
f 2 3 = 5 := by
grind

example (f : Nat → Nat) : f = (fun x : Nat => x + 5) → f 2 > 5 := by
grind

example (f : Nat → Nat → Nat) : f a = (fun x : Nat => x + 5) → f a 2 > 5 := by
grind
6 changes: 2 additions & 4 deletions tests/lean/run/grind_canon_insts.lean
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,13 @@ theorem left_comm [CommMonoid α] (a b c : α) : a * (b * c) = b * (a * c) := by

open Lean Meta Elab Tactic Grind in
def fallback : Fallback := do
let nodes ← filterENodes fun e => return e.self.isAppOf ``HMul.hMul
let nodes ← filterENodes fun e => return e.self.isApp && e.self.isAppOf ``HMul.hMul
trace[Meta.debug] "{nodes.toList.map (·.self)}"
(← get).mvarId.admit

set_option trace.Meta.debug true

/--
info: [Meta.debug] [b * c, a * (b * c), d * (b * c)]
-/
/-- info: [Meta.debug] [b * c, a * (b * c), d * (b * c)] -/
#guard_msgs (info) in
example (a b c d : Nat) : b * (a * c) = d * (b * c) → False := by
rw [left_comm] -- Introduces a new (non-canonical) instance for `Mul Nat`
Expand Down
2 changes: 1 addition & 1 deletion tests/lean/run/grind_canon_types.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ def f (a : α) := a

open Lean Meta Grind in
def fallback : Fallback := do
let nodes ← filterENodes fun e => return e.self.isAppOf ``f
let nodes ← filterENodes fun e => return e.self.isApp && e.self.isAppOf ``f
trace[Meta.debug] "{nodes.toList.map (·.self)}"
(← get).mvarId.admit

Expand Down
2 changes: 1 addition & 1 deletion tests/lean/run/grind_congr.lean
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ def g (a : Nat) := a + a
-- Prints the equivalence class containing a `f` application
open Lean Meta Grind in
def fallback : Fallback := do
let #[n, _] ← filterENodes fun e => return e.self.isAppOf ``f | unreachable!
let #[n, _] ← filterENodes fun e => return e.self.isApp && e.self.isAppOf ``f | unreachable!
let eqc ← getEqc n.self
trace[Meta.debug] eqc
(← get).mvarId.admit
Expand Down
2 changes: 1 addition & 1 deletion tests/lean/run/grind_many_eqs.lean
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def fallback (n : Nat) : Fallback := do
-- The `f 0` equivalence class contains `n+1` elements
assert! (← getEqc f0).length == n + 1
forEachENode fun node => do
if node.self.isAppOf ``g then
if node.self.isApp && node.self.isAppOf ``g then
-- Any equivalence class containing a `g`-application contains 2 elements
assert! (← getEqc (← getRoot node.self)).length == 2
(← get).mvarId.admit
Expand Down
2 changes: 1 addition & 1 deletion tests/lean/run/grind_nested_proofs.lean
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ open Lean Meta Grind in
def fallback : Fallback := do
let nodes ← filterENodes fun e => return e.self.isAppOf ``Lean.Grind.nestedProof
trace[Meta.debug] "{nodes.toList.map (·.self)}"
let nodes ← filterENodes fun e => return e.self.isAppOf ``GetElem.getElem
let nodes ← filterENodes fun e => return e.self.isApp && e.self.isAppOf ``GetElem.getElem
let [_, n, _] := nodes.toList | unreachable!
trace[Meta.debug] "{← getEqc n.self}"
(← get).mvarId.admit
Expand Down
Loading

0 comments on commit a062eea

Please sign in to comment.