Skip to content

Commit

Permalink
feat: add term offset support to the grind E-matching modulo (#6533)
Browse files Browse the repository at this point in the history
This PR adds support to E-matching offset patterns. For example, we want
to be able to E-match the pattern `f (#0 + 1)` with term `f (a + 2)`.
  • Loading branch information
leodemoura authored Jan 5, 2025
1 parent 9dcbc33 commit dc5c809
Show file tree
Hide file tree
Showing 7 changed files with 206 additions and 19 deletions.
3 changes: 3 additions & 0 deletions src/Init/Grind/Norm.lean
Original file line number Diff line number Diff line change
Expand Up @@ -114,4 +114,7 @@ attribute [grind_norm] Nat.le_zero_eq
-- GT GE
attribute [grind_norm] GT.gt GE.ge

-- Succ
attribute [grind_norm] Nat.succ_eq_add_one

end Lean.Grind
3 changes: 3 additions & 0 deletions src/Init/Grind/Util.lean
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ We use it when adding instances of `match`-equations to prevent them from being
-/
def doNotSimp {α : Sort u} (a : α) : α := a

/-- Gadget for representing offsets `t+k` in patterns. -/
def offset (a b : Nat) : Nat := a + b

set_option pp.proofs true

theorem nestedProof_congr (p q : Prop) (h : p = q) (hp : p) (hq : q) : HEq (nestedProof p hp) (nestedProof q hq) := by
Expand Down
6 changes: 4 additions & 2 deletions src/Lean/Elab/Tactic/Grind.lean
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@ def elabGrindPattern : CommandElab := fun stx => do
let info ← getConstInfo declName
forallTelescope info.type fun xs _ => do
let patterns ← terms.getElems.mapM fun term => do
let pattern ← instantiateMVars (← elabTerm term none)
let pattern ← Grind.unfoldReducible pattern
let pattern ← elabTerm term none
synthesizeSyntheticMVarsUsingDefault
let pattern ← instantiateMVars pattern
let pattern ← Grind.preprocessPattern pattern
return pattern.abstract xs
Grind.addEMatchTheorem declName xs.size patterns.toList
| _ => throwUnsupportedSyntax
Expand Down
73 changes: 60 additions & 13 deletions src/Lean/Meta/Tactic/Grind/EMatch.lean
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ namespace EMatch
inductive Cnstr where
| /-- Matches pattern `pat` with term `e` -/
«match» (pat : Expr) (e : Expr)
| /-- Matches offset pattern `pat+k` with term `e` -/
offset (pat : Expr) (k : Nat) (e : Expr)
| /-- This constraint is used to encode multi-patterns. -/
«continue» (pat : Expr)
deriving Inhabited
Expand Down Expand Up @@ -88,6 +90,28 @@ private def eqvFunctions (pFn eFn : Expr) : Bool :=
(pFn.isFVar && pFn == eFn)
|| (pFn.isConst && eFn.isConstOf pFn.constName!)

/-- Matches a pattern argument. See `matchArgs?`. -/
private def matchArg? (c : Choice) (pArg : Expr) (eArg : Expr) : OptionT GoalM Choice := do
if isPatternDontCare pArg then
return c
else if pArg.isBVar then
assign? c pArg.bvarIdx! eArg
else if let some pArg := groundPattern? pArg then
guard (← isEqv pArg eArg)
return c
else if let some (pArg, k) := isOffsetPattern? pArg then
assert! Option.isNone <| isOffsetPattern? pArg
assert! !isPatternDontCare pArg
return { c with cnstrs := .offset pArg k eArg :: c.cnstrs }
else
return { c with cnstrs := .match pArg eArg :: c.cnstrs }

private def Choice.updateGen (c : Choice) (gen : Nat) : Choice :=
{ c with gen := Nat.max gen c.gen }

private def pushChoice (c : Choice) : M Unit :=
modify fun s => { s with choiceStack := c :: s.choiceStack }

/--
Matches arguments of pattern `p` with term `e`. Returns `some` if successful,
and `none` otherwise. It may update `c`s assignment and list of contraints to be
Expand All @@ -97,16 +121,8 @@ private partial def matchArgs? (c : Choice) (p : Expr) (e : Expr) : OptionT Goal
if !p.isApp then return c -- Done
let pArg := p.appArg!
let eArg := e.appArg!
let goFn c := matchArgs? c p.appFn! e.appFn!
if isPatternDontCare pArg then
goFn c
else if pArg.isBVar then
goFn (← assign? c pArg.bvarIdx! eArg)
else if let some pArg := groundPattern? pArg then
guard (← isEqv pArg eArg)
goFn c
else
goFn { c with cnstrs := .match pArg eArg :: c.cnstrs }
let c ← matchArg? c pArg eArg
matchArgs? c p.appFn! e.appFn!

/--
Matches pattern `p` with term `e` with respect to choice `c`.
Expand All @@ -127,9 +143,39 @@ private partial def processMatch (c : Choice) (p : Expr) (e : Expr) : M Unit :=
&& eqvFunctions pFn curr.getAppFn
&& curr.getAppNumArgs == numArgs then
if let some c ← matchArgs? c p curr |>.run then
let gen := n.generation
let c := { c with gen := Nat.max gen c.gen }
modify fun s => { s with choiceStack := c :: s.choiceStack }
pushChoice (c.updateGen n.generation)
curr ← getNext curr
if isSameExpr curr e then break

/--
Matches offset pattern `pArg+k` with term `e` with respect to choice `c`.
-/
private partial def processOffset (c : Choice) (pArg : Expr) (k : Nat) (e : Expr) : M Unit := do
let maxGeneration ← getMaxGeneration
let mut curr := e
repeat
let n ← getENode curr
if n.generation <= maxGeneration then
if let some (eArg, k') ← isOffset? curr |>.run then
if k' < k then
let c := c.updateGen n.generation
pushChoice { c with cnstrs := .offset pArg (k - k') eArg :: c.cnstrs }
else if k' == k then
if let some c ← matchArg? c pArg eArg |>.run then
pushChoice (c.updateGen n.generation)
else if k' > k then
let eArg' := mkNatAdd eArg (mkNatLit (k' - k))
let eArg' ← shareCommon (← canon eArg')
internalize eArg' n.generation
if let some c ← matchArg? c pArg eArg' |>.run then
pushChoice (c.updateGen n.generation)
else if let some k' ← evalNat curr |>.run then
if k' >= k then
let eArg' := mkNatLit (k' - k)
let eArg' ← shareCommon (← canon eArg')
internalize eArg' n.generation
if let some c ← matchArg? c pArg eArg' |>.run then
pushChoice (c.updateGen n.generation)
curr ← getNext curr
if isSameExpr curr e then break

Expand Down Expand Up @@ -224,6 +270,7 @@ private partial def processChoices : M Unit := do
match c.cnstrs with
| [] => instantiateTheorem c
| .match p e :: cnstrs => processMatch { c with cnstrs } p e
| .offset p k e :: cnstrs => processOffset { c with cnstrs } p k e
| .continue p :: cnstrs => processContinue { c with cnstrs } p
processChoices

Expand Down
45 changes: 42 additions & 3 deletions src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,45 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Init.Grind.Util
import Lean.HeadIndex
import Lean.PrettyPrinter
import Lean.Util.FoldConsts
import Lean.Util.CollectFVars
import Lean.Meta.Basic
import Lean.Meta.InferType
import Lean.Meta.Tactic.Grind.Util

namespace Lean.Meta.Grind

def mkOffsetPattern (pat : Expr) (k : Nat) : Expr :=
mkApp2 (mkConst ``Grind.offset) pat (mkRawNatLit k)

private def detectOffsets (pat : Expr) : MetaM Expr := do
let pre (e : Expr) := do
if e == pat then
-- We only consider nested offset patterns
return .continue e
else match e with
| .letE .. | .lam .. | .forallE .. => return .done e
| _ =>
let some (e, k) ← isOffset? e
| return .continue e
if k == 0 then return .continue e
return .continue <| mkOffsetPattern e k
Core.transform pat (pre := pre)

def isOffsetPattern? (pat : Expr) : Option (Expr × Nat) := Id.run do
let_expr Grind.offset pat k := pat | none
let .lit (.natVal k) := k | none
return some (pat, k)

def preprocessPattern (pat : Expr) : MetaM Expr := do
let pat ← instantiateMVars pat
let pat ← unfoldReducible pat
let pat ← detectOffsets pat
return pat

inductive Origin where
/-- A global declaration in the environment. -/
| decl (declName : Name)
Expand Down Expand Up @@ -202,6 +232,12 @@ private def getPatternFunMask (f : Expr) (numArgs : Nat) : MetaM (Array Bool) :=
private partial def go (pattern : Expr) (root := false) : M Expr := do
if root && !pattern.hasLooseBVars then
throwError "invalid pattern, it does not have pattern variables"
if let some (e, k) := isOffsetPattern? pattern then
let e ← goArg e (isSupport := false)
if e == dontCare then
return dontCare
else
return mkOffsetPattern e k
let some f := getPatternFn? pattern
| throwError "invalid pattern, (non-forbidden) application expected"
assert! f.isConst || f.isFVar
Expand All @@ -211,7 +247,11 @@ private partial def go (pattern : Expr) (root := false) : M Expr := do
for i in [:args.size] do
let arg := args[i]!
let isSupport := supportMask[i]?.getD false
let arg ← if !arg.hasLooseBVars then
args := args.set! i (← goArg arg isSupport)
return mkAppN f args
where
goArg (arg : Expr) (isSupport : Bool) : M Expr := do
if !arg.hasLooseBVars then
if arg.hasMVar then
pure dontCare
else
Expand All @@ -230,8 +270,6 @@ private partial def go (pattern : Expr) (root := false) : M Expr := do
go arg
else
pure dontCare
args := args.set! i arg
return mkAppN f args

def main (patterns : List Expr) : MetaM (List Expr × List HeadIndex × Std.HashSet Nat) := do
let (patterns, s) ← patterns.mapM go |>.run {}
Expand Down Expand Up @@ -390,6 +428,7 @@ def mkEMatchEqTheorem (declName : Name) : MetaM EMatchTheorem := do
let info ← getConstInfo declName
let (numParams, patterns) ← forallTelescopeReducing info.type fun xs type => do
let_expr Eq _ lhs _ := type | throwError "invalid E-matching equality theorem, conclusion must be an equality{indentExpr type}"
let lhs ← preprocessPattern lhs
return (xs.size, [lhs.abstract xs])
mkEMatchTheorem declName numParams patterns

Expand Down
8 changes: 7 additions & 1 deletion src/Lean/Meta/Tactic/Grind/Main.lean
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Authors: Leonardo de Moura
prelude
import Init.Grind.Lemmas
import Lean.Meta.Tactic.Util
import Lean.Meta.Tactic.Simp.Simproc
import Lean.Meta.Tactic.Grind.RevertAll
import Lean.Meta.Tactic.Grind.PropagatorAttr
import Lean.Meta.Tactic.Grind.Proj
Expand Down Expand Up @@ -34,12 +35,17 @@ def mkMethods (fallback : Fallback) : CoreM Methods := do
prop e
}

private def getGrindSimprocs : MetaM Simprocs := do
let s ← grindNormSimprocExt.getSimprocs
let s ← addDoNotSimp s
return s

def GrindM.run (x : GrindM α) (mainDeclName : Name) (config : Grind.Config) (fallback : Fallback) : MetaM α := do
let scState := ShareCommon.State.mk _
let (falseExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``False)
let (trueExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``True)
let thms ← grindNormExt.getTheorems
let simprocs := #[(← addDoNotSimp (← grindNormSimprocExt.getSimprocs))]
let simprocs := #[(← getGrindSimprocs), (← Simp.getSEvalSimprocs)]
let simp ← Simp.mkContext
(config := { arith := true })
(simpTheorems := #[thms])
Expand Down
87 changes: 87 additions & 0 deletions tests/lean/run/grind_offset.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
opaque g : Nat → Nat

@[simp] def f (a : Nat) :=
match a with
| 0 => 10
| x+1 => g (f x)

set_option trace.grind.ematch.pattern true
set_option trace.grind.ematch.instance true
set_option trace.grind.assert true

/--
info: [grind.ematch.pattern] f.eq_2: [f (Lean.Grind.offset #0 (1))]
-/
#guard_msgs in
grind_pattern f.eq_2 => f (x + 1)


/--
info: [grind.assert] f (y + 1) = a
[grind.assert] ¬a = g (f y)
[grind.ematch.instance] f.eq_2: f y.succ = g (f y)
[grind.assert] f (y + 1) = g (f y)
-/
#guard_msgs (info) in
example : f (y + 1) = a → a = g (f y):= by
grind

/--
info: [grind.assert] f 1 = a
[grind.assert] ¬a = g (f 0)
[grind.ematch.instance] f.eq_2: f (Nat.succ 0) = g (f 0)
[grind.assert] f 1 = g (f 0)
-/
#guard_msgs (info) in
example : f 1 = a → a = g (f 0) := by
grind

/--
info: [grind.assert] f 10 = a
[grind.assert] ¬a = g (f 9)
[grind.ematch.instance] f.eq_2: f (Nat.succ 8) = g (f 8)
[grind.ematch.instance] f.eq_2: f (Nat.succ 9) = g (f 9)
[grind.assert] f 9 = g (f 8)
[grind.assert] f 10 = g (f 9)
-/
#guard_msgs (info) in
example : f 10 = a → a = g (f 9) := by
grind

/--
info: [grind.assert] f (c + 2) = a
[grind.assert] ¬a = g (g (f c))
[grind.ematch.instance] f.eq_2: f (c + 1).succ = g (f (c + 1))
[grind.assert] f (c + 2) = g (f (c + 1))
[grind.ematch.instance] f.eq_2: f c.succ = g (f c)
[grind.assert] f (c + 1) = g (f c)
-/
#guard_msgs (info) in
example : f (c + 2) = a → a = g (g (f c)) := by
grind

@[simp] def foo (a : Nat) :=
match a with
| 0 => 10
| 1 => 10
| a+2 => g (foo a)

/--
info: [grind.ematch.pattern] foo.eq_3: [foo (Lean.Grind.offset #0 (2))]
-/
#guard_msgs in
grind_pattern foo.eq_3 => foo (a_2 + 2)

-- The instance is correctly found in the following example.
-- TODO: to complete the proof, we need linear arithmetic support to prove that `b + 2 = c + 1`.
/--
info: [grind.assert] foo (c + 1) = a
[grind.assert] c = b + 1
[grind.assert] ¬a = g (foo b)
[grind.ematch.instance] foo.eq_3: foo b.succ.succ = g (foo b)
[grind.assert] foo (b + 2) = g (foo b)
-/
#guard_msgs (info) in
example : foo (c + 1) = a → c = b + 1 → a = g (foo b) := by
fail_if_success grind
sorry

0 comments on commit dc5c809

Please sign in to comment.