-
Notifications
You must be signed in to change notification settings - Fork 460
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: beta reduction in
grind
(#6700)
This PR adds support for beta reduction in the `grind` tactic. `grind` can now solve goals such as ```lean example (f : Nat → Nat) : f = (fun x : Nat => x + 5) → f 2 > 5 := by grind ```
- Loading branch information
1 parent
645bdea
commit a062eea
Showing
12 changed files
with
221 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
/- | ||
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
Released under Apache 2.0 license as described in the file LICENSE. | ||
Authors: Leonardo de Moura | ||
-/ | ||
prelude | ||
import Lean.Meta.Tactic.Grind.Types | ||
|
||
namespace Lean.Meta.Grind | ||
|
||
/-- Returns all lambda expressions in the equivalence class with root `root`. -/ | ||
def getEqcLambdas (root : ENode) : GoalM (Array Expr) := do | ||
unless root.hasLambdas do return #[] | ||
foldEqc root.self (init := #[]) fun n lams => | ||
if n.self.isLambda then return lams.push n.self else return lams | ||
|
||
/-- | ||
Returns the root of the functions in the equivalence class containing `e`. | ||
That is, if `f a` is in `root`s equivalence class, results contains the root of `f`. | ||
-/ | ||
def getFnRoots (e : Expr) : GoalM (Array Expr) := do | ||
foldEqc e (init := #[]) fun n fns => do | ||
let fn := n.self.getAppFn | ||
let fnRoot := (← getRoot? fn).getD fn | ||
if Option.isNone <| fns.find? (isSameExpr · fnRoot) then | ||
return fns.push fnRoot | ||
else | ||
return fns | ||
|
||
/-- | ||
For each `lam` in `lams` s.t. `lam` and `f` are in the same equivalence class, | ||
propagate `f args = lam args`. | ||
-/ | ||
def propagateBetaEqs (lams : Array Expr) (f : Expr) (args : Array Expr) : GoalM Unit := do | ||
if args.isEmpty then return () | ||
for lam in lams do | ||
let rhs := lam.beta args | ||
unless rhs.isLambda do | ||
let mut gen := Nat.max (← getGeneration lam) (← getGeneration f) | ||
let lhs := mkAppN f args | ||
if (← hasSameType f lam) then | ||
let mut h ← mkEqProof f lam | ||
for arg in args do | ||
gen := Nat.max gen (← getGeneration arg) | ||
h ← mkCongrFun h arg | ||
let eq ← mkEq lhs rhs | ||
trace[grind.beta] "{eq}, using {lam}" | ||
addNewFact h eq (gen+1) | ||
|
||
private def isPropagateBetaTarget (e : Expr) : GoalM Bool := do | ||
let .app f _ := e | return false | ||
go f | ||
where | ||
go (f : Expr) : GoalM Bool := do | ||
if let some root ← getRootENode? f then | ||
return root.hasLambdas | ||
let .app f _ := f | return false | ||
go f | ||
|
||
/-- | ||
Applies beta-reduction for lambdas in `f`s equivalence class. | ||
We use this function while internalizing new applications. | ||
-/ | ||
def propagateBetaForNewApp (e : Expr) : GoalM Unit := do | ||
unless (← isPropagateBetaTarget e) do return () | ||
let mut e := e | ||
let mut args := #[] | ||
repeat | ||
unless args.isEmpty do | ||
if let some root ← getRootENode? e then | ||
if root.hasLambdas then | ||
propagateBetaEqs (← getEqcLambdas root) e args.reverse | ||
let .app f arg := e | return () | ||
e := f | ||
args := args.push arg | ||
|
||
end Lean.Meta.Grind |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
def f (x : Nat) : Nat → Nat → Nat := | ||
match x with | ||
| 0 => fun _ _ => 0 | ||
| _+1 => fun a b => a + b | ||
|
||
example : f 0 b c = 0 := by | ||
grind [f] | ||
|
||
example : f (a+1) b c = b + c := by | ||
grind [f] | ||
|
||
example : f x b c ≠ b + c → x = a + 1 → False := by | ||
grind [f] | ||
|
||
example : x = a + 1 → f x b c ≠ b + c → False := by | ||
grind [f] | ||
|
||
example : x = a + 1 → f x b c ≠ b + c → False := by | ||
grind [f] | ||
|
||
example : f x b c > 0 → x = 0 → False := by | ||
grind [f] | ||
|
||
example : f x b c > 0 → x ≠ 0 := by | ||
grind [f] | ||
|
||
example (f : Nat → Nat → Nat) : f 2 3 ≠ 5 → f = (fun x y : Nat => x + y) → False := by | ||
grind | ||
|
||
opaque bla : Nat → Nat → Nat → Nat | ||
|
||
/-- | ||
info: [grind.beta] f 2 3 = bla 2 3 2, using fun x y => bla x y x | ||
[grind.beta] f 2 3 = 2 + 3, using fun x y => x + y | ||
-/ | ||
#guard_msgs (info) in | ||
set_option trace.grind.beta true in | ||
example (g h f : Nat → Nat → Nat) : | ||
f 2 3 ≠ 5 → | ||
g = (fun x y : Nat => x + y) → | ||
h = (fun x y => bla x y x) → | ||
g = h → | ||
f = h → | ||
False := by | ||
grind | ||
|
||
example (g h f : Nat → Nat → Nat) : | ||
f 2 3 ≠ 5 → | ||
h = (fun x y => bla x y x) → | ||
g = (fun x y : Nat => x + y) → | ||
g = h → | ||
h = f → | ||
False := by | ||
grind | ||
|
||
|
||
example (f : Nat → Nat → Nat) : f = (fun x y : Nat => x + y) → f 2 3 = 5 := by | ||
grind | ||
|
||
example (f g h : Nat → Nat → Nat) : | ||
h = (fun x y => bla x y x) → | ||
g = (fun x y : Nat => x + y) → | ||
g = h → | ||
h = f → | ||
f 2 3 = 5 := by | ||
grind | ||
|
||
example (f : Nat → Nat) : f = (fun x : Nat => x + 5) → f 2 > 5 := by | ||
grind | ||
|
||
example (f : Nat → Nat → Nat) : f a = (fun x : Nat => x + 5) → f a 2 > 5 := by | ||
grind |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.