Skip to content

Commit

Permalink
fix: improve typeOccursCheck
Browse files Browse the repository at this point in the history
This commit addresses a Mathlib failure.
  • Loading branch information
leodemoura committed Jun 10, 2024
1 parent fb01c65 commit 72e5528
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 3 deletions.
37 changes: 34 additions & 3 deletions src/Lean/Meta/ExprDefEq.lean
Original file line number Diff line number Diff line change
Expand Up @@ -929,6 +929,29 @@ partial def check

end CheckAssignmentQuick

/--
Auxiliary function used at `typeOccursCheckImp`.
Given `type`, it tries to eliminate "dependencies". For example, suppose we are trying to
perform the assignment `?m := f (?n a b)` where
```
?n : let k := g ?m; A -> h k ?m -> C
```
If we just perform occurs check `?m` at the type of `?n`, we get a failure, but
we claim these occurrences are ok because the type `?n a b : C`.
In the example above, `typeOccursCheckImp` invokes this function with `n := 2`.
Note that we avoid using `whnf` and `inferType` at `typeOccursCheckImp` to minimize the
performance impact of this extra check.
See test `typeOccursCheckIssue.lean` for an example where this refinement is needed.
The test is derived from a Mathlib file.
-/
private partial def skipAtMostNumBinders (type : Expr) (n : Nat) : Expr :=
match type, n with
| .forallE _ _ b _, n+1 => skipAtMostNumBinders b n
| .mdata _ b, n => skipAtMostNumBinders b n
| .letE _ _ v b _, n => skipAtMostNumBinders (b.instantiate1 v) n
| type, _ => type

/-- `typeOccursCheck` implementation using unsafe (i.e., pointer equality) features. -/
private unsafe def typeOccursCheckImp (mctx : MetavarContext) (mvarId : MVarId) (v : Expr) : Bool :=
if v.hasExprMVar then
Expand All @@ -949,11 +972,19 @@ where
-- this function assumes all assigned metavariables have already been
-- instantiated.
go.run' mctx
visitMVar (mvarId' : MVarId) : Bool :=
visitMVar (mvarId' : MVarId) (numArgs : Nat := 0) : Bool :=
if let some mvarDecl := mctx.findDecl? mvarId' then
occursCheck mvarDecl.type
occursCheck (skipAtMostNumBinders mvarDecl.type numArgs)
else
false
visitApp (e : Expr) : StateM (PtrSet Expr) Bool :=
e.withApp fun f args => do
unless (← args.allM visit) do
return false
if f.isMVar then
return visitMVar f.mvarId! args.size
else
visit f
visit (e : Expr) : StateM (PtrSet Expr) Bool := do
if !e.hasExprMVar then
return true
Expand All @@ -962,7 +993,7 @@ where
else match e with
| .mdata _ b => visit b
| .proj _ _ s => visit s
| .app f a => visit f <&&> visit a
| .app .. => visitApp e
| .lam _ d b _ => visit d <&&> visit b
| .forallE _ d b _ => visit d <&&> visit b
| .letE _ t v b _ => visit t <&&> visit v <&&> visit b
Expand Down
76 changes: 76 additions & 0 deletions tests/lean/run/typeOccursCheckIssue.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
namespace SlimCheck

inductive TestResult (p : Prop) where
| success : PSum Unit p → TestResult p
| gaveUp : Nat → TestResult p
| failure : ¬ p → List String → Nat → TestResult p
deriving Inhabited

/-- Configuration for testing a property. -/
structure Configuration where
numInst : Nat := 100
maxSize : Nat := 100
numRetries : Nat := 10
traceDiscarded : Bool := false
traceSuccesses : Bool := false
traceShrink : Bool := false
traceShrinkCandidates : Bool := false
randomSeed : Option Nat := none
quiet : Bool := false
deriving Inhabited

abbrev Rand := Id

abbrev Gen (α : Type u) := ReaderT (ULift Nat) Rand α

/-- `Testable p` uses random examples to try to disprove `p`. -/
class Testable (p : Prop) where
run (cfg : Configuration) (minimize : Bool) : Gen (TestResult p)

def NamedBinder (_n : String) (p : Prop) : Prop := p

namespace TestResult

def isFailure : TestResult p → Bool
| failure _ _ _ => true
| _ => false

end TestResult

namespace Testable

open TestResult

def runProp (p : Prop) [Testable p] : Configuration → Bool → Gen (TestResult p) := Testable.run

variable {var : String}

def addShrinks (n : Nat) : TestResult p → TestResult p
| TestResult.failure p xs m => TestResult.failure p xs (m + n)
| p => p

instance [Pure m] : Inhabited (OptionT m α) := ⟨(pure none : m (Option α))⟩

class Shrinkable (α : Type u) where
shrink : (x : α) → List α := fun _ ↦ []

class SampleableExt (α : Sort u) where
proxy : Type v
[proxyRepr : Repr proxy]
[shrink : Shrinkable proxy]
sample : Gen proxy
interp : proxy → α

partial def minimizeAux [SampleableExt α] {β : α → Prop} [∀ x, Testable (β x)] (cfg : Configuration)
(var : String) (x : SampleableExt.proxy α) (n : Nat) :
OptionT Gen (Σ x, TestResult (β (SampleableExt.interp x))) := do
let candidates := SampleableExt.shrink.shrink x
for candidate in candidates do
let res ← OptionT.lift <| Testable.runProp (β (SampleableExt.interp candidate)) cfg true
if res.isFailure then
if cfg.traceShrink then
pure () -- slimTrace s!"{var} shrunk to {repr candidate} from {repr x}"
let currentStep := OptionT.lift <| pure <| Sigma.mk candidate (addShrinks (n + 1) res)
let nextStep := minimizeAux cfg var candidate (n + 1)
return ← (nextStep <|> currentStep)
failure

0 comments on commit 72e5528

Please sign in to comment.