Skip to content

Commit

Permalink
Began implementing well-formed predicates
Browse files Browse the repository at this point in the history
  • Loading branch information
JOSHCLUNE committed Jan 6, 2025
1 parent 1966d39 commit 6f70ba4
Show file tree
Hide file tree
Showing 5 changed files with 294 additions and 64 deletions.
60 changes: 43 additions & 17 deletions Auto/IR/SMT.lean
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ mutual
| sConst : SpecConst → STerm
| bvar : Nat → STerm -- De bruijin index
| qIdApp : QualIdent → Array STerm → STerm -- Application of function symbol to array of terms
| testerApp : String → STerm → STerm -- Application of a datatype tester of the form `(_ is ctor)` to a term
| letE : (name : String) → (binding : STerm) → (body : STerm) → STerm
| forallE : (name : String) → (binderType : SSort) → (body : STerm) → STerm
| existE : (name : String) → (binderType : SSort) → (body : STerm) → STerm
Expand Down Expand Up @@ -155,6 +156,8 @@ private partial def STerm.toStringAux : STerm → List SIdent → String
let intro := s!"({si} "
let tail := String.intercalate " " (STerm.toStringAux a binders :: goQIdApp as binders)
intro ++ tail ++ ")"
| .testerApp ctorName s, binders =>
"((_ is " ++ ctorName ++ ") " ++ (STerm.toStringAux s binders) ++ ")"
| .letE name binding body, binders =>
let binders := (SIdent.symb name) :: binders
let intro := s!"(let ({SIdent.symb name} "
Expand Down Expand Up @@ -228,7 +231,7 @@ structure ConstrDecl where
name : String
selDecls : Array (String × SSort)

private def ConstrDecl.toString : ConstrDecl → Array SIdent → String
def ConstrDecl.toString : ConstrDecl → Array SIdent → String
| ⟨name, selDecls⟩, binders =>
let pre := s!"({SIdent.symb name}"
let selDecls := selDecls.map (fun (name, sort) => s!"({SIdent.symb name} " ++ SSort.toString sort binders ++ ")")
Expand All @@ -241,7 +244,7 @@ structure DatatypeDecl where
params : Array String
cstrDecls : Array ConstrDecl

private def DatatypeDecl.toString : DatatypeDecl → String := fun ⟨params, cstrDecls⟩ =>
def DatatypeDecl.toString : DatatypeDecl → String := fun ⟨params, cstrDecls⟩ =>
let scstrDecls := cstrDecls.map (fun d => ConstrDecl.toString d (params.map SIdent.symb))
let scstrDecls := "(" ++ String.intercalate " " scstrDecls.toList ++ ")"
if params.size == 0 then
Expand Down Expand Up @@ -376,6 +379,8 @@ section

-- Type of (identifiers in higher-level logic)
variable (ω : Type) [BEq ω] [Hashable ω]
-- Type of (sorts in higher-level logic)
variable (φ : Type) [BEq φ] [Hashable φ]

/--
The main purpose of this state is for name generation
Expand Down Expand Up @@ -403,41 +408,47 @@ section
-- been used for `k` times (`k > 0`), return `n' ++ s!"_{k - 1}"`.
-- `usedNames` records the `k - 1` for each `n'`
usedNames : Std.HashMap String Nat := {}
-- Map from SMT sorts to the names of their corresponding well-formed predicates.
-- If an SMT sort's well-formed predicate would be equivalent to `True`, no
-- well-formed predicate needs to be created, so `wfPredicatesMap` maps that sort
-- to `none`
wfPredicatesMap : Std.HashMap φ (Option String) := {}
-- Inverse of `wfPredicates`
wfPredicatesInvMap : Std.HashMap String φ := {}
-- List of commands
commands : Array Command := #[]

abbrev TransM := StateRefT (State ω) MetaM

variable {ω : Type} [BEq ω] [Hashable ω]
abbrev TransM := StateRefT (State ω φ) MetaM

@[always_inline]
instance : Monad (TransM ω) :=
let i := inferInstanceAs (Monad (TransM ω));
instance : Monad (TransM ω φ) :=
let i := inferInstanceAs (Monad (TransM ω φ));
{ pure := i.pure, bind := i.bind }

instance : Inhabited (TransM ω α) where
instance : Inhabited (TransM ω φ α) where
default := fun _ => throw default

variable {ω : Type} [BEq ω] [Hashable ω] [ToString ω]
variable {φ : Type} [BEq φ] [Hashable φ] [ToString φ]

@[inline] def TransM.run (x : TransM ω α) (s : State ω := {}) : MetaM (α × State ω) :=
@[inline] def TransM.run (x : TransM ω φ α) (s : State ω φ := {}) : MetaM (α × State ω φ) :=
StateRefT'.run x s

@[inline] def TransM.run' (x : TransM ω α) (s : State ω := {}) : MetaM α :=
@[inline] def TransM.run' (x : TransM ω φ α) (s : State ω φ := {}) : MetaM α :=
Prod.fst <$> StateRefT'.run x s

#genMonadState (TransM ω)
#genMonadState (TransM ω φ)

def getMapSize : TransM ω Nat := do
def getMapSize : TransM ω φ Nat := do
let size := (← getH2lMap).size
assert! ((← getL2hMap).size == size)
return size

def hIn (e : ω) : TransM ω Bool := do
def hIn (e : ω) : TransM ω φ Bool := do
return (← getH2lMap).contains e

/- Note that this function will add the processed name to `usedNames` -/
def processSuggestedName (nameSuggestion : String) : TransM ω String := do
def processSuggestedName (nameSuggestion : String) : TransM ω φ String := do
let mut preName := nameSuggestion.map (fun c => if allowed c then c else '_')
if preName.all (fun c => c == '_') then
preName := "pl_" ++ preName
Expand All @@ -462,13 +473,14 @@ section


/- Generate names that does not correspond to high-level construct -/
partial def disposableName (nameSuggestion : String) : TransM ω String := processSuggestedName nameSuggestion
partial def disposableName (nameSuggestion : String) : TransM ω φ String := processSuggestedName nameSuggestion

/--
Turn high-level construct into low-level symbol
Note that this function is idempotent
-/
partial def h2Symb (cstr : ω) (nameSuggestion : Option String) : TransM ω String := do
partial def h2Symb (cstr : ω) (nameSuggestion : Option String) : TransM ω φ String := do
trace[auto.lamFOL2SMT] "Calling h2Symb on {cstr} with nameSuggestion {nameSuggestion}"
let l2hMap ← getL2hMap
let h2lMap ← getH2lMap
if let .some name := h2lMap.get? cstr then
Expand All @@ -480,7 +492,21 @@ section
setH2lMap (h2lMap.insert cstr name)
return name

def addCommand (c : Command) : TransM ω Unit := do
/-- Like `hySymb` but produces names for well-formed predicates of sort `s` (of type `φ`) rather than of
constructs (of type `ω`) -/
partial def h2SymbWf (s : φ) (nameSuggestion : Option String) : TransM ω φ (Option String) := do
let wfPredicatesMap ← getWfPredicatesMap
let wfPredicatesInvMap ← getWfPredicatesInvMap
if let some name := wfPredicatesMap.get? s then
return name
let .some nameSuggestion := nameSuggestion
| throwError "{decl_name%} :: Fresh well-formed predicate for {s} without name suggestion"
let name ← processSuggestedName nameSuggestion
setWfPredicatesInvMap (wfPredicatesInvMap.insert name s)
setWfPredicatesMap (wfPredicatesMap.insert s name)
return name

def addCommand (c : Command) : TransM ω φ Unit := do
let commands ← getCommands
setCommands (commands.push c)

Expand Down
2 changes: 2 additions & 0 deletions Auto/Parser/LexInit.lean
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def underscore : ERE := .ofStr "_"

def SMTforall : ERE := .ofStr "forall"
def SMTexists : ERE := .ofStr "exists"
def SMTlambda : ERE := .ofStr "lambda" -- This is not part of SMT-lib 2.6 but can be output by cvc5's hints
def SMTlet : ERE := .ofStr "let"

/-- Special constants -/
Expand Down Expand Up @@ -137,6 +138,7 @@ def term : ERE := .plus #[
specConst,
.attr SMTforall "forall",
.attr SMTexists "exists",
.attr SMTlambda "lambda",
.attr SMTlet "let",
.attr lparen "(",
.attr rparen ")",
Expand Down
45 changes: 40 additions & 5 deletions Auto/Parser/SMTParser.lean
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def LexVal.ofString (s : String) (attr : String) : LexVal :=
| "reserved" => .reserved s
| "forall" => .reserved "forall"
| "exists" => .reserved "exists"
| "lambda" => .reserved "lambda"
| "let" => .reserved "let"
| "_" => .underscore
| _ => panic! s!"LexVal.ofString :: {repr attr} is not a valid attribute"
Expand Down Expand Up @@ -179,9 +180,9 @@ def lexTerm [Monad m] [Lean.MonadError m] (s : String) (p : String.Pos)
match (SMT.lexiconADFA.getAttrs state).toList with
| [attr] => pure attr
| [attr1, attr2] =>
if attr1 == "forall" || attr1 == "exists" || attr1 == "let" || attr1 == "_" then pure attr1
else if attr2 == "forall" || attr2 == "exists" || attr2 == "let" || attr2 == "_" then pure attr2
else throwError "parseTerm :: Attribute conflict not caused by forall, exists, let, or _"
if attr1 == "forall" || attr1 == "exists" || attr1 == "lambda" || attr1 == "let" || attr1 == "_" then pure attr1
else if attr2 == "forall" || attr2 == "exists" || attr2 == "lambda" || attr2 == "let" || attr2 == "_" then pure attr2
else throwError "parseTerm :: Attribute conflict not caused by forall, exists, lambda, let, or _"
| _ => throwError "parseTerm :: Invalid number of attributes"

p := matched.stopPos
Expand All @@ -196,7 +197,7 @@ def lexTerm [Monad m] [Lean.MonadError m] (s : String) (p : String.Pos)
-- Too many right parentheses
return .malformed
else
let final := pstk.back
let final := pstk.back!
pstk := pstk.pop
if pstk.size == 0 then
return .complete (.app final) p
Expand Down Expand Up @@ -293,14 +294,15 @@ partial def getExplicitForallArgumentTypes (e : Expr) : List Expr :=
| Expr.forallE _ _t b _ => getExplicitForallArgumentTypes b -- Skip over t because this binder is implicit
| _ => []

-- **TODO** Generalize this to make it possible to impose arbitrary constraints (which is helpful for hints like `(= _wfNat (lambda ((x Int)) (>= x 0)))`
inductive ParseTermConstraint
| mustBeProp
| mustBeBool
| noConstraint

open ParseTermConstraint

/-- A helper function for `parseForall` and `parseExists`
/-- A helper function for `parseForall`, `parseExists`, and `parseLambda`
When parsing the arguments of SMT forall and exists expressions, the SMT type "Bool" can appear, which sometimes must be interpreted
as `Prop` and sometimes must be interpreted as `Bool`. In `parseForall` and `parseExists`, if there are `x` "Bool" binders, then there
Expand Down Expand Up @@ -419,6 +421,38 @@ partial def parseExists (vs : List Term) (symbolMap : Std.HashMap String Expr) :
continue
throwError "parseExists :: Failed to parse exists expression with vs: {vs}"

partial def parseLambdaBodyWithSortedVars (vs : List Term) (sortedVars : Array (String × Expr))
(symbolMap : Std.HashMap String Expr) (lambdaBody : Term) : MetaM Expr := do
withLocalDeclsD (sortedVars.map fun (n, ty) => (n.toName, fun _ => pure ty)) fun _ => do
let lctx ← getLCtx
let mut symbolMap := symbolMap
let mut sortedVarDecls := #[]
for sortedVar in sortedVars do
let some sortedVarDecl := lctx.findFromUserName? sortedVar.1.toName
| throwError "parseForall :: Unknown sorted var name {sortedVar.1} (parseForall input: {vs})"
symbolMap := symbolMap.insert sortedVar.1 (mkFVar sortedVarDecl.fvarId)
sortedVarDecls := sortedVarDecls.push sortedVarDecl
let body ← parseTerm lambdaBody symbolMap mustBeProp
Meta.mkLambdaFVars (sortedVarDecls.map (fun decl => mkFVar decl.fvarId)) body

partial def parseLambda (vs : List Term) (symbolMap : Std.HashMap String Expr) : MetaM Expr := do
let [app sortedVars, existsBody] := vs
| throwError "parseLambda :: Unexpected input list {vs}"
let sortedVars ← sortedVars.mapM (fun sv => parseSortedVar sv symbolMap)
let sortedVarsWithIndices := sortedVars.mapFinIdx (fun idx val => (val, idx))
let mut curPropBoolChoice := some $ (sortedVarsWithIndices.filter (fun ((_, t), _) => t.isProp)).map (fun (_, idx) => (idx, false))
let mut possibleSortedVars := #[]
while curPropBoolChoice.isSome do
let (nextSortedVars, nextCurPropBoolChoice) := getNextSortedVars sortedVars curPropBoolChoice.get!
possibleSortedVars := possibleSortedVars.push nextSortedVars
curPropBoolChoice := nextCurPropBoolChoice
for sortedVars in possibleSortedVars do
try
return ← parseLambdaBodyWithSortedVars vs sortedVars symbolMap existsBody
catch _ =>
continue
throwError "parseLambda :: Failed to parse exists expression with vs: {vs}"

/-- Given a varBinding of the form `(symbol value)` returns the string of the symbol, the type of the value, and the value itself -/
partial def parseVarBinding (varBinding : Term) (symbolMap : Std.HashMap String Expr) : MetaM (String × Expr × Expr) := do
match varBinding with
Expand Down Expand Up @@ -558,6 +592,7 @@ partial def parseTerm (e : Term) (symbolMap : Std.HashMap String Expr) (parseTer
match vs.toList with
| atom (reserved "forall") :: restVs => parseForall restVs symbolMap
| atom (reserved "exists") :: restVs => parseExists restVs symbolMap
| atom (reserved "lambda") :: restVs => parseLambda restVs symbolMap
| atom (reserved "let") :: restVs => parseLet restVs symbolMap parseTermConstraint
| atom (symb "=>") :: restVs => parseImplication restVs symbolMap
| app #[atom underscore, atom (symb "is"), ctor] :: [testerArg] =>
Expand Down
9 changes: 7 additions & 2 deletions Auto/Tactic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,8 @@ def querySMTForHints (exportFacts : Array REntry) (exportInds : Array MutualIndI
| _ => throwError "runAuto :: Unexpected error")
let sni : SMT.SMTNamingInfo :=
{tyVal := (← LamReif.getTyVal), varVal := (← LamReif.getVarVal), lamEVarTy := (← LamReif.getLamEVarTy)}
let ((commands, validFacts, l2hMap, selInfos), state) ← (lamFOL2SMTWithExtraInfo sni lamVarTy lamEVarTy exportLamTerms exportInds).run
let ((commands, validFacts, l2hMap, wfPredicatesInvMap, selInfos), state) ←
(lamFOL2SMTWithExtraInfo sni lamVarTy lamEVarTy exportLamTerms exportInds).run
for cmd in commands do
trace[auto.smt.printCommands] "{cmd}"
if (auto.smt.save.get (← getOptions)) then
Expand Down Expand Up @@ -475,14 +476,18 @@ def querySMTForHints (exportFacts : Array REntry) (exportInds : Array MutualIndI
let vderiv ← LamReif.collectDerivFor (.valid [] t)
unsatCoreDerivLeafStrings := unsatCoreDerivLeafStrings ++ vderiv.collectLeafStrings
trace[auto.smt.unsatCore.deriv] "|valid_fact_{id}| : {vderiv}"
-- **Build symbolPrecMap using l2hMap and selInfos**
-- **Build symbolPrecMap using l2hMap, wfPredicatesInvMap, and selInfos**
let (preprocessFacts, theoryLemmas, instantiations, computationLemmas, polynomialLemmas, rewriteFacts) := solverHints
let mut symbolMap : Std.HashMap String Expr := Std.HashMap.empty
for (varName, varAtom) in l2hMap.toArray do
let varLeanExp ←
SMT.withExprValuation sni state.h2lMap (fun tyValMap varValMap etomValMap => do
SMT.LamAtomic.toLeanExpr tyValMap varValMap etomValMap varAtom)
symbolMap := symbolMap.insert varName varLeanExp
for (wfPredicateName, wfPredicateSort) in wfPredicatesInvMap.toArray do
let ty ← SMT.withExprValuation sni state.h2lMap (fun tyValMap _ _ => Lam2D.interpLamSortAsUnlifted tyValMap wfPredicateSort)
let tyPred := .lam .anonymous ty (mkConst ``True) .default -- Interpret `_wf_α` as `fun _ : α => True`
symbolMap := symbolMap.insert wfPredicateName tyPred
/- `selectorArr` has entries containing:
- The name of an SMT selector function
- The constructor it is a selector for
Expand Down
Loading

0 comments on commit 6f70ba4

Please sign in to comment.