From dc5c8097b5ca053dd49b5de246add54c3e999734 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sun, 5 Jan 2025 03:20:17 +0100 Subject: [PATCH] feat: add term offset support to the `grind` E-matching modulo (#6533) This PR adds support to E-matching offset patterns. For example, we want to be able to E-match the pattern `f (#0 + 1)` with term `f (a + 2)`. --- src/Init/Grind/Norm.lean | 3 + src/Init/Grind/Util.lean | 3 + src/Lean/Elab/Tactic/Grind.lean | 6 +- src/Lean/Meta/Tactic/Grind/EMatch.lean | 73 +++++++++++++--- src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean | 45 +++++++++- src/Lean/Meta/Tactic/Grind/Main.lean | 8 +- tests/lean/run/grind_offset.lean | 87 +++++++++++++++++++ 7 files changed, 206 insertions(+), 19 deletions(-) create mode 100644 tests/lean/run/grind_offset.lean diff --git a/src/Init/Grind/Norm.lean b/src/Init/Grind/Norm.lean index 72bc3021ec1a..d911cf634703 100644 --- a/src/Init/Grind/Norm.lean +++ b/src/Init/Grind/Norm.lean @@ -114,4 +114,7 @@ attribute [grind_norm] Nat.le_zero_eq -- GT GE attribute [grind_norm] GT.gt GE.ge +-- Succ +attribute [grind_norm] Nat.succ_eq_add_one + end Lean.Grind diff --git a/src/Init/Grind/Util.lean b/src/Init/Grind/Util.lean index 9c7523d992cb..9a231e318162 100644 --- a/src/Init/Grind/Util.lean +++ b/src/Init/Grind/Util.lean @@ -18,6 +18,9 @@ We use it when adding instances of `match`-equations to prevent them from being -/ def doNotSimp {α : Sort u} (a : α) : α := a +/-- Gadget for representing offsets `t+k` in patterns. -/ +def offset (a b : Nat) : Nat := a + b + set_option pp.proofs true theorem nestedProof_congr (p q : Prop) (h : p = q) (hp : p) (hq : q) : HEq (nestedProof p hp) (nestedProof q hq) := by diff --git a/src/Lean/Elab/Tactic/Grind.lean b/src/Lean/Elab/Tactic/Grind.lean index b9a62e4baed7..d094ae98e4bb 100644 --- a/src/Lean/Elab/Tactic/Grind.lean +++ b/src/Lean/Elab/Tactic/Grind.lean @@ -26,8 +26,10 @@ def elabGrindPattern : CommandElab := fun stx => do let info ← getConstInfo declName forallTelescope info.type fun xs _ => do let patterns ← terms.getElems.mapM fun term => do - let pattern ← instantiateMVars (← elabTerm term none) - let pattern ← Grind.unfoldReducible pattern + let pattern ← elabTerm term none + synthesizeSyntheticMVarsUsingDefault + let pattern ← instantiateMVars pattern + let pattern ← Grind.preprocessPattern pattern return pattern.abstract xs Grind.addEMatchTheorem declName xs.size patterns.toList | _ => throwUnsupportedSyntax diff --git a/src/Lean/Meta/Tactic/Grind/EMatch.lean b/src/Lean/Meta/Tactic/Grind/EMatch.lean index 42031260a0a7..ca6df2b7a07c 100644 --- a/src/Lean/Meta/Tactic/Grind/EMatch.lean +++ b/src/Lean/Meta/Tactic/Grind/EMatch.lean @@ -16,6 +16,8 @@ namespace EMatch inductive Cnstr where | /-- Matches pattern `pat` with term `e` -/ «match» (pat : Expr) (e : Expr) + | /-- Matches offset pattern `pat+k` with term `e` -/ + offset (pat : Expr) (k : Nat) (e : Expr) | /-- This constraint is used to encode multi-patterns. -/ «continue» (pat : Expr) deriving Inhabited @@ -88,6 +90,28 @@ private def eqvFunctions (pFn eFn : Expr) : Bool := (pFn.isFVar && pFn == eFn) || (pFn.isConst && eFn.isConstOf pFn.constName!) +/-- Matches a pattern argument. See `matchArgs?`. -/ +private def matchArg? (c : Choice) (pArg : Expr) (eArg : Expr) : OptionT GoalM Choice := do + if isPatternDontCare pArg then + return c + else if pArg.isBVar then + assign? c pArg.bvarIdx! eArg + else if let some pArg := groundPattern? pArg then + guard (← isEqv pArg eArg) + return c + else if let some (pArg, k) := isOffsetPattern? pArg then + assert! Option.isNone <| isOffsetPattern? pArg + assert! !isPatternDontCare pArg + return { c with cnstrs := .offset pArg k eArg :: c.cnstrs } + else + return { c with cnstrs := .match pArg eArg :: c.cnstrs } + +private def Choice.updateGen (c : Choice) (gen : Nat) : Choice := + { c with gen := Nat.max gen c.gen } + +private def pushChoice (c : Choice) : M Unit := + modify fun s => { s with choiceStack := c :: s.choiceStack } + /-- Matches arguments of pattern `p` with term `e`. Returns `some` if successful, and `none` otherwise. It may update `c`s assignment and list of contraints to be @@ -97,16 +121,8 @@ private partial def matchArgs? (c : Choice) (p : Expr) (e : Expr) : OptionT Goal if !p.isApp then return c -- Done let pArg := p.appArg! let eArg := e.appArg! - let goFn c := matchArgs? c p.appFn! e.appFn! - if isPatternDontCare pArg then - goFn c - else if pArg.isBVar then - goFn (← assign? c pArg.bvarIdx! eArg) - else if let some pArg := groundPattern? pArg then - guard (← isEqv pArg eArg) - goFn c - else - goFn { c with cnstrs := .match pArg eArg :: c.cnstrs } + let c ← matchArg? c pArg eArg + matchArgs? c p.appFn! e.appFn! /-- Matches pattern `p` with term `e` with respect to choice `c`. @@ -127,9 +143,39 @@ private partial def processMatch (c : Choice) (p : Expr) (e : Expr) : M Unit := && eqvFunctions pFn curr.getAppFn && curr.getAppNumArgs == numArgs then if let some c ← matchArgs? c p curr |>.run then - let gen := n.generation - let c := { c with gen := Nat.max gen c.gen } - modify fun s => { s with choiceStack := c :: s.choiceStack } + pushChoice (c.updateGen n.generation) + curr ← getNext curr + if isSameExpr curr e then break + +/-- +Matches offset pattern `pArg+k` with term `e` with respect to choice `c`. +-/ +private partial def processOffset (c : Choice) (pArg : Expr) (k : Nat) (e : Expr) : M Unit := do + let maxGeneration ← getMaxGeneration + let mut curr := e + repeat + let n ← getENode curr + if n.generation <= maxGeneration then + if let some (eArg, k') ← isOffset? curr |>.run then + if k' < k then + let c := c.updateGen n.generation + pushChoice { c with cnstrs := .offset pArg (k - k') eArg :: c.cnstrs } + else if k' == k then + if let some c ← matchArg? c pArg eArg |>.run then + pushChoice (c.updateGen n.generation) + else if k' > k then + let eArg' := mkNatAdd eArg (mkNatLit (k' - k)) + let eArg' ← shareCommon (← canon eArg') + internalize eArg' n.generation + if let some c ← matchArg? c pArg eArg' |>.run then + pushChoice (c.updateGen n.generation) + else if let some k' ← evalNat curr |>.run then + if k' >= k then + let eArg' := mkNatLit (k' - k) + let eArg' ← shareCommon (← canon eArg') + internalize eArg' n.generation + if let some c ← matchArg? c pArg eArg' |>.run then + pushChoice (c.updateGen n.generation) curr ← getNext curr if isSameExpr curr e then break @@ -224,6 +270,7 @@ private partial def processChoices : M Unit := do match c.cnstrs with | [] => instantiateTheorem c | .match p e :: cnstrs => processMatch { c with cnstrs } p e + | .offset p k e :: cnstrs => processOffset { c with cnstrs } p k e | .continue p :: cnstrs => processContinue { c with cnstrs } p processChoices diff --git a/src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean b/src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean index 4b03d9021fe5..8b50693e1523 100644 --- a/src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean +++ b/src/Lean/Meta/Tactic/Grind/EMatchTheorem.lean @@ -4,15 +4,45 @@ Released under Apache 2.0 license as described in the file LICENSE. Authors: Leonardo de Moura -/ prelude +import Init.Grind.Util import Lean.HeadIndex import Lean.PrettyPrinter import Lean.Util.FoldConsts import Lean.Util.CollectFVars import Lean.Meta.Basic import Lean.Meta.InferType +import Lean.Meta.Tactic.Grind.Util namespace Lean.Meta.Grind +def mkOffsetPattern (pat : Expr) (k : Nat) : Expr := + mkApp2 (mkConst ``Grind.offset) pat (mkRawNatLit k) + +private def detectOffsets (pat : Expr) : MetaM Expr := do + let pre (e : Expr) := do + if e == pat then + -- We only consider nested offset patterns + return .continue e + else match e with + | .letE .. | .lam .. | .forallE .. => return .done e + | _ => + let some (e, k) ← isOffset? e + | return .continue e + if k == 0 then return .continue e + return .continue <| mkOffsetPattern e k + Core.transform pat (pre := pre) + +def isOffsetPattern? (pat : Expr) : Option (Expr × Nat) := Id.run do + let_expr Grind.offset pat k := pat | none + let .lit (.natVal k) := k | none + return some (pat, k) + +def preprocessPattern (pat : Expr) : MetaM Expr := do + let pat ← instantiateMVars pat + let pat ← unfoldReducible pat + let pat ← detectOffsets pat + return pat + inductive Origin where /-- A global declaration in the environment. -/ | decl (declName : Name) @@ -202,6 +232,12 @@ private def getPatternFunMask (f : Expr) (numArgs : Nat) : MetaM (Array Bool) := private partial def go (pattern : Expr) (root := false) : M Expr := do if root && !pattern.hasLooseBVars then throwError "invalid pattern, it does not have pattern variables" + if let some (e, k) := isOffsetPattern? pattern then + let e ← goArg e (isSupport := false) + if e == dontCare then + return dontCare + else + return mkOffsetPattern e k let some f := getPatternFn? pattern | throwError "invalid pattern, (non-forbidden) application expected" assert! f.isConst || f.isFVar @@ -211,7 +247,11 @@ private partial def go (pattern : Expr) (root := false) : M Expr := do for i in [:args.size] do let arg := args[i]! let isSupport := supportMask[i]?.getD false - let arg ← if !arg.hasLooseBVars then + args := args.set! i (← goArg arg isSupport) + return mkAppN f args +where + goArg (arg : Expr) (isSupport : Bool) : M Expr := do + if !arg.hasLooseBVars then if arg.hasMVar then pure dontCare else @@ -230,8 +270,6 @@ private partial def go (pattern : Expr) (root := false) : M Expr := do go arg else pure dontCare - args := args.set! i arg - return mkAppN f args def main (patterns : List Expr) : MetaM (List Expr × List HeadIndex × Std.HashSet Nat) := do let (patterns, s) ← patterns.mapM go |>.run {} @@ -390,6 +428,7 @@ def mkEMatchEqTheorem (declName : Name) : MetaM EMatchTheorem := do let info ← getConstInfo declName let (numParams, patterns) ← forallTelescopeReducing info.type fun xs type => do let_expr Eq _ lhs _ := type | throwError "invalid E-matching equality theorem, conclusion must be an equality{indentExpr type}" + let lhs ← preprocessPattern lhs return (xs.size, [lhs.abstract xs]) mkEMatchTheorem declName numParams patterns diff --git a/src/Lean/Meta/Tactic/Grind/Main.lean b/src/Lean/Meta/Tactic/Grind/Main.lean index 98214895ee8d..8ee85469bdc3 100644 --- a/src/Lean/Meta/Tactic/Grind/Main.lean +++ b/src/Lean/Meta/Tactic/Grind/Main.lean @@ -6,6 +6,7 @@ Authors: Leonardo de Moura prelude import Init.Grind.Lemmas import Lean.Meta.Tactic.Util +import Lean.Meta.Tactic.Simp.Simproc import Lean.Meta.Tactic.Grind.RevertAll import Lean.Meta.Tactic.Grind.PropagatorAttr import Lean.Meta.Tactic.Grind.Proj @@ -34,12 +35,17 @@ def mkMethods (fallback : Fallback) : CoreM Methods := do prop e } +private def getGrindSimprocs : MetaM Simprocs := do + let s ← grindNormSimprocExt.getSimprocs + let s ← addDoNotSimp s + return s + def GrindM.run (x : GrindM α) (mainDeclName : Name) (config : Grind.Config) (fallback : Fallback) : MetaM α := do let scState := ShareCommon.State.mk _ let (falseExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``False) let (trueExpr, scState) := ShareCommon.State.shareCommon scState (mkConst ``True) let thms ← grindNormExt.getTheorems - let simprocs := #[(← addDoNotSimp (← grindNormSimprocExt.getSimprocs))] + let simprocs := #[(← getGrindSimprocs), (← Simp.getSEvalSimprocs)] let simp ← Simp.mkContext (config := { arith := true }) (simpTheorems := #[thms]) diff --git a/tests/lean/run/grind_offset.lean b/tests/lean/run/grind_offset.lean new file mode 100644 index 000000000000..7711ece39be1 --- /dev/null +++ b/tests/lean/run/grind_offset.lean @@ -0,0 +1,87 @@ +opaque g : Nat → Nat + +@[simp] def f (a : Nat) := + match a with + | 0 => 10 + | x+1 => g (f x) + +set_option trace.grind.ematch.pattern true +set_option trace.grind.ematch.instance true +set_option trace.grind.assert true + +/-- +info: [grind.ematch.pattern] f.eq_2: [f (Lean.Grind.offset #0 (1))] +-/ +#guard_msgs in +grind_pattern f.eq_2 => f (x + 1) + + +/-- +info: [grind.assert] f (y + 1) = a +[grind.assert] ¬a = g (f y) +[grind.ematch.instance] f.eq_2: f y.succ = g (f y) +[grind.assert] f (y + 1) = g (f y) +-/ +#guard_msgs (info) in +example : f (y + 1) = a → a = g (f y):= by + grind + +/-- +info: [grind.assert] f 1 = a +[grind.assert] ¬a = g (f 0) +[grind.ematch.instance] f.eq_2: f (Nat.succ 0) = g (f 0) +[grind.assert] f 1 = g (f 0) +-/ +#guard_msgs (info) in +example : f 1 = a → a = g (f 0) := by + grind + +/-- +info: [grind.assert] f 10 = a +[grind.assert] ¬a = g (f 9) +[grind.ematch.instance] f.eq_2: f (Nat.succ 8) = g (f 8) +[grind.ematch.instance] f.eq_2: f (Nat.succ 9) = g (f 9) +[grind.assert] f 9 = g (f 8) +[grind.assert] f 10 = g (f 9) +-/ +#guard_msgs (info) in +example : f 10 = a → a = g (f 9) := by + grind + +/-- +info: [grind.assert] f (c + 2) = a +[grind.assert] ¬a = g (g (f c)) +[grind.ematch.instance] f.eq_2: f (c + 1).succ = g (f (c + 1)) +[grind.assert] f (c + 2) = g (f (c + 1)) +[grind.ematch.instance] f.eq_2: f c.succ = g (f c) +[grind.assert] f (c + 1) = g (f c) +-/ +#guard_msgs (info) in +example : f (c + 2) = a → a = g (g (f c)) := by + grind + +@[simp] def foo (a : Nat) := + match a with + | 0 => 10 + | 1 => 10 + | a+2 => g (foo a) + +/-- +info: [grind.ematch.pattern] foo.eq_3: [foo (Lean.Grind.offset #0 (2))] +-/ +#guard_msgs in +grind_pattern foo.eq_3 => foo (a_2 + 2) + +-- The instance is correctly found in the following example. +-- TODO: to complete the proof, we need linear arithmetic support to prove that `b + 2 = c + 1`. +/-- +info: [grind.assert] foo (c + 1) = a +[grind.assert] c = b + 1 +[grind.assert] ¬a = g (foo b) +[grind.ematch.instance] foo.eq_3: foo b.succ.succ = g (foo b) +[grind.assert] foo (b + 2) = g (foo b) +-/ +#guard_msgs (info) in +example : foo (c + 1) = a → c = b + 1 → a = g (foo b) := by + fail_if_success grind + sorry