Skip to content

Commit

Permalink
Better projection handling, beginning to deal with Nat/Int casts
Browse files Browse the repository at this point in the history
  • Loading branch information
JOSHCLUNE committed Jan 7, 2025
1 parent 6f70ba4 commit cdba40c
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 43 deletions.
58 changes: 41 additions & 17 deletions Auto/Parser/SMTParser.lean
Original file line number Diff line number Diff line change
Expand Up @@ -548,18 +548,19 @@ partial def parseTerm (e : Term) (symbolMap : Std.HashMap String Expr) (parseTer
| atom (symb s) =>
match symbolMap.get? s with
| some v =>
let vType ← inferType v
match parseTermConstraint with
| noConstraint => return v
| noConstraint =>
if vType == mkConst ``Nat then mkAppM ``Int.ofNat #[v]
else return v
| mustBeProp =>
let vType ← inferType v
if vType.isProp then
return v
else if vType == mkConst ``Bool then
mkAppM ``Eq #[v, mkConst ``true]
else
throwError "parseTerm :: {e} is parsed as {v} which is not a Prop"
| mustBeBool =>
let vType ← inferType v
if vType == mkConst ``Bool then
return v
else if vType.isProp then
Expand All @@ -569,18 +570,19 @@ partial def parseTerm (e : Term) (symbolMap : Std.HashMap String Expr) (parseTer
| none =>
match builtInSymbolMap.get? s with
| some v =>
let vType ← inferType v
match parseTermConstraint with
| noConstraint => return v
| noConstraint =>
if vType == mkConst ``Nat then mkAppM ``Int.ofNat #[v]
else return v
| mustBeProp =>
let vType ← inferType v
if vType.isProp then
return v
else if vType == mkConst ``Bool then
mkAppM ``Eq #[vType, mkConst ``true]
else
throwError "parseTerm :: {e} is parsed as {v} which is not a Prop"
| mustBeBool =>
let vType ← inferType v
if vType == mkConst ``Bool then
return v
else if vType.isProp then
Expand Down Expand Up @@ -721,27 +723,49 @@ partial def parseTerm (e : Term) (symbolMap : Std.HashMap String Expr) (parseTer
| some symbolExp =>
let symbolExpType ← inferType symbolExp
let expectedArgTypes := getExplicitForallArgumentTypes symbolExpType
let argConstraints := expectedArgTypes.map
(fun argType =>
if argType.isProp then mustBeProp
else if argType == mkConst ``Bool then mustBeBool
else noConstraint
let args ← (restVs.zip expectedArgTypes).mapM
(fun (t, expectedArgType) => do
if expectedArgType.isProp then
parseTerm t symbolMap mustBeProp
else if expectedArgType == mkConst ``Bool then
parseTerm t symbolMap mustBeBool
else if expectedArgType == mkConst ``Nat then
let arg ← parseTerm t symbolMap noConstraint
let argType ← inferType arg
if argType == mkConst ``Nat then
pure arg
else if argType == mkConst ``Int then
mkAppM ``Int.natAbs #[arg]
else
throwError "parseTerm :: {e} includes term {t} which is parsed as {arg} which is not a Nat"
else if expectedArgType == mkConst ``Int then
let arg ← parseTerm t symbolMap noConstraint
let argType ← inferType arg
if argType == mkConst ``Int then
pure arg
else if argType == mkConst ``Nat then
mkAppM ``Int.ofNat #[arg]
else
throwError "parseTerm :: {e} includes term {t} which is parsed as {arg} which is not an Int"
else
parseTerm t symbolMap noConstraint
)
let args ← (restVs.zip argConstraints).mapM (fun (t, argConstraint) => parseTerm t symbolMap argConstraint)
let res ← mkAppM' symbolExp args.toArray
let resType ← inferType res
match parseTermConstraint with
| noConstraint => mkAppM' symbolExp args.toArray
| noConstraint =>
if resType == mkConst ``Nat then
mkAppM ``Int.ofNat #[res]
else
return res
| mustBeProp =>
let res ← mkAppM' symbolExp args.toArray
let resType ← inferType res
if resType.isProp then
return res
else if resType == mkConst ``Bool then
mkAppM ``Eq #[res, mkConst ``true]
else
throwError "parseTerm :: {e} is parsed as {res} which is not a Prop"
| mustBeBool =>
let res ← mkAppM' symbolExp args.toArray
let resType ← inferType res
if resType == mkConst ``Bool then
return res
else if resType.isProp then
Expand Down
29 changes: 15 additions & 14 deletions Auto/Tactic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -494,20 +494,21 @@ def querySMTForHints (exportFacts : Array REntry) (exportInds : Array MutualIndI
- The index of the argument it is a selector for
- The mvar used to represent the selector function -/
let mut selectorArr : Array (String × Expr × Nat × Expr) := #[]
for (selName, selCtor, argIdx, datatypeName, selOutputType) in selInfos do
let selCtor ←
SMT.withExprValuation sni state.h2lMap (fun tyValMap varValMap etomValMap => do
SMT.LamAtomic.toLeanExpr tyValMap varValMap etomValMap selCtor)
let selOutputType ←
SMT.withExprValuation sni state.h2lMap (fun tyValMap _ _ => Lam2D.interpLamSortAsUnlifted tyValMap selOutputType)
let selDatatype ←
match symbolMap.get? datatypeName with
| some selDatatype => pure selDatatype
| none => throwError "querySMTForHints :: Could not find the datatype {datatypeName} corresponding to selector {selName}"
let selType := Expr.forallE `x selDatatype selOutputType .default
let selMVar ← Meta.mkFreshExprMVar selType
selectorArr := selectorArr.push (selName, selCtor, argIdx, selMVar)
symbolMap := symbolMap.insert selName selMVar
for (selName, selIsProjection, selCtor, argIdx, datatypeName, selOutputType) in selInfos do
if !selIsProjection then -- Projections already have corresponding values in Lean and therefore don't need to be added to `selectorArr`
let selCtor ←
SMT.withExprValuation sni state.h2lMap (fun tyValMap varValMap etomValMap => do
SMT.LamAtomic.toLeanExpr tyValMap varValMap etomValMap selCtor)
let selOutputType ←
SMT.withExprValuation sni state.h2lMap (fun tyValMap _ _ => Lam2D.interpLamSortAsUnlifted tyValMap selOutputType)
let selDatatype ←
match symbolMap.get? datatypeName with
| some selDatatype => pure selDatatype
| none => throwError "querySMTForHints :: Could not find the datatype {datatypeName} corresponding to selector {selName}"
let selType := Expr.forallE `x selDatatype selOutputType .default
let selMVar ← Meta.mkFreshExprMVar selType
selectorArr := selectorArr.push (selName, selCtor, argIdx, selMVar)
symbolMap := symbolMap.insert selName selMVar
let selectorMVars := selectorArr.map (fun (_, _, _, selMVar) => selMVar)
-- Change the last argument of selectorArr from the mvar used to represent the selector function to its type
selectorArr ← selectorArr.mapM
Expand Down
25 changes: 13 additions & 12 deletions Auto/Translation/LamFOL2SMT.lean
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,13 @@ instance : Ord LamAtomic where

/-- selectorInfo contains:
- The name of the selector
- Whether the selector is a projection
- The constructor that the selector is for
- The index of the argument that it is selecting
- The name of the SMT datatype that the selector is for
- The output type of the selector (full type is an arrow type that takes in
the datatype and returns the output type) -/
abbrev SelectorInfo := String × LamAtomic × Nat × String × LamSort
abbrev SelectorInfo := String × Bool × LamAtomic × Nat × String × LamSort

def LamAtomic.toLeanExpr
(tyValMap varValMap etomValMap : Std.HashMap Nat Expr)
Expand Down Expand Up @@ -190,24 +191,24 @@ mutual
| _ => -- **TODO** This approach does not adequately address mutually inductive datatypes (assumes all ctors are of same datatype, and other issues)
/- **TODO** Determine datatype name by calling `← h2Symb (.sort n) none` rather than using the first selInfo's datatype name
(to ensure we get the correct specific datatype when considering mutually inductive datatypes) -/
let datatypeName := selInfos[0]!.2.2.2.1 -- Guaranteed to not panic because `selInfos.size > 0`
let datatypeName := selInfos[0]!.2.2.2.2.1 -- Guaranteed to not panic because `selInfos.size > 0`
let some sWfConstraintName ← h2SymbWf s $ some ("wf" ++ datatypeName.capitalize)
| throwError "{decl_name%} :: h2SymbWf returned none given {s} even though {s} has a nontrivial well-formed predicate"
trace[auto.lamFOL2SMT] "defineWfConstraint :: {s} has datatype name {datatypeName}"
let mut wfCstrTerms : Array STerm := #[]
-- Gather a list of selector infos for each constructor
let mut ctorInfos : Std.HashMap LamAtomic (List SelectorInfo) := Std.HashMap.empty
for (selName, ctor, argIdx, datatypeName, selOutputType) in selInfos do
for (selName, selIsProjection, ctor, argIdx, datatypeName, selOutputType) in selInfos do
if ctorInfos.contains ctor then
ctorInfos := ctorInfos.modify ctor (fun acc => (selName, ctor, argIdx, datatypeName, selOutputType) :: acc)
ctorInfos := ctorInfos.modify ctor (fun acc => (selName, selIsProjection, ctor, argIdx, datatypeName, selOutputType) :: acc)
else
ctorInfos := ctorInfos.insert ctor [(selName, ctor, argIdx, datatypeName, selOutputType)]
ctorInfos := ctorInfos.insert ctor [(selName, selIsProjection, ctor, argIdx, datatypeName, selOutputType)]
-- Iterate through each constructor to build `wfCstrTerms`
for (ctor, ctorSelInfos) in ctorInfos do
let ctorName ← h2Symb ctor none -- `none` should never cause an error since `ctor` should have been given a symbol when the datatype was defined
let ctorTester := .testerApp ctorName (.bvar 0) -- `.bvar 0` refers to the element of sort `s` being tested
let mut wfSelectorTerms : Array STerm := #[]
for (selName, _, argIdx, datatypeName, selOutputType) in ctorSelInfos do
for (selName, selIsProjection, _, argIdx, datatypeName, selOutputType) in ctorSelInfos do
trace[auto.lamFOL2SMT] "defineWfConstraint :: Examining selector {selName} for datatype {datatypeName} which has output type {selOutputType}"
match ← getWfConstraint sni selOutputType none with
| some selOutputTypeWfConstraint => wfSelectorTerms := wfSelectorTerms.push $ .qStrApp selOutputTypeWfConstraint #[.qStrApp selName #[.bvar 0]]
Expand Down Expand Up @@ -672,19 +673,19 @@ private def lamMutualIndInfo2STermWithInfos (sni : SMTNamingInfo) (mind : Mutual
let lamSortArgTys := s.getArgTys -- `argTys` as `LamSort` rather than `SSort`
let mut selDecls := #[]
if projs.isSome then
if argTys.length != projInfos.size then
if argTys.length != projInfos.size || lamSortArgTys.length != projInfos.size then
throwError "lamMutualIndInfo2STerm :: Unexpected error"
selDecls := ((Array.mk argTys).zip projInfos).map (fun (argTy, _, name) => (name, argTy))
/- We don't update `selInfos` because `projs` exist for this inductive datatype. Since `projs` exists,
the selector function we want will already correspond to an existing projection function in Lean,
meaning we don't need to define a different selector function to have something to map the current
selDecls onto -/
let selDeclsInfos :=
((Array.mk lamSortArgTys).zip projInfos).zipWithIndex.map
(fun ((lamSortArgTy, _, name), idx) => (name, true, tAtomic, idx, sname, lamSortArgTy))
selInfos := selInfos ++ selDeclsInfos
else
selDecls := (Array.mk argTys).zipWithIndex.map (fun (argTy, idx) =>
(ctorname ++ s!"_sel{idx}", argTy))
let selDeclsInfos :=
(Array.mk lamSortArgTys).zipWithIndex.map
(fun (lamSortArgTy, idx) => (ctorname ++ s!"_sel{idx}", tAtomic, idx, sname, lamSortArgTy))
(fun (lamSortArgTy, idx) => (ctorname ++ s!"_sel{idx}", false, tAtomic, idx, sname, lamSortArgTy))
selInfos := selInfos ++ selDeclsInfos
cstrDecls := cstrDecls.push ⟨ctorname, selDecls⟩
infos := infos.push (sname, 0, ⟨#[], cstrDecls⟩)
Expand Down

0 comments on commit cdba40c

Please sign in to comment.