Skip to content

Commit

Permalink
Added Nat.sub and ite support to SMTParser
Browse files Browse the repository at this point in the history
  • Loading branch information
JOSHCLUNE committed Jan 9, 2025
1 parent 650f7f5 commit 008a6c5
Showing 1 changed file with 54 additions and 17 deletions.
71 changes: 54 additions & 17 deletions Auto/Parser/SMTParser.lean
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ inductive SymbolInput
| TwoExactNoConstraint -- Used for symbols like `<` that take in exactly two nonProp/nonBool arguments
| TwoExactEq -- Specifically used for `=` which can invoke Prop typing constraints if a Prop and Bool are equated
| Minus -- Minus is left-associative when given ≥ 2 arguments but is also used for unary negation
| Ite -- Used for `ite` which takes in exactly three arguments

open SymbolInput

Expand All @@ -250,12 +251,14 @@ def smtSymbolToLeanName (s : String) : List (Name × SymbolInput) :=
| ">=" => [(``GE.ge, TwoExactNoConstraint)]
| "+" => [(``HAdd.hAdd, LeftAssocNoConstraint)]
| "-" => [(``HSub.hSub, Minus)] -- Minus is left-associative when given ≥ 2 arguments but is also used for unary negation
| "nsub" => [(``Nat.sub, Minus)]
| "*" => [(``HMul.hMul, LeftAssocNoConstraint)]
| "/" => [(``HDiv.hDiv, LeftAssocNoConstraint)]
| "or" => [(``Or, LeftAssocAllProp), (``or, LeftAssocAllBool)]
| "and" => [(``And, LeftAssocAllProp), (``and, LeftAssocAllBool)]
| "not" => [(``Not, UnaryProp), (``not, UnaryBool)]
| "=" => [(``Eq, TwoExactEq)]
| "ite" => [(``Ite, Ite)]
| _ => []

def builtInSymbolMap : Std.HashMap String Expr :=
Expand All @@ -265,6 +268,20 @@ def builtInSymbolMap : Std.HashMap String Expr :=
let map := map.insert "Bool" (.sort .zero)
let map := map.insert "false" (mkConst ``False)
let map := map.insert "true" (mkConst ``True)
let map := map.insert "<" (mkConst ``LT.lt)
let map := map.insert "<=" (mkConst ``LE.le)
let map := map.insert ">" (mkConst ``GT.gt)
let map := map.insert ">=" (mkConst ``GE.ge)
let map := map.insert "+" (mkConst ``HAdd.hAdd)
let map := map.insert "-" (mkConst ``HSub.hSub)
let map := map.insert "nsub" (mkConst ``Nat.sub)
let map := map.insert "*" (mkConst ``HMul.hMul)
let map := map.insert "/" (mkConst ``HDiv.hDiv)
let map := map.insert "or" (mkConst ``Or)
let map := map.insert "and" (mkConst ``And)
let map := map.insert "not" (mkConst ``Not)
let map := map.insert "=" (mkConst ``Eq)
let map := map.insert "ite" (mkConst ``Ite)
map

/-- Given an expression `∀ x1 : t1, x2 : t2, ... xn : tn, b`, returns `[t1, t2, ..., tn]`. If the given expression is not
Expand Down Expand Up @@ -336,14 +353,14 @@ def correctType (e : Expr) (parseTermConstraint : ParseTermConstraint) : MetaM E
let eType ← inferType e
match parseTermConstraint with
| noConstraint =>
if eType == mkConst ``Nat then mkAppM ``Int.ofNat #[e]
if ← isDefEq eType (mkConst ``Nat) then mkAppM ``Int.ofNat #[e]
else return e
| expectedType t =>
if eType == t then return e
else if eType.isProp && t == mkConst ``Bool then whnf $ ← mkAppOptM ``decide #[some e, none]
else if eType == mkConst ``Bool && t.isProp then whnf $ ← mkAppM ``eq_true #[e]
else if eType == mkConst ``Nat && t == mkConst ``Int then return ← mkAppM ``Int.ofNat #[e]
else if eType == mkConst ``Int && t == mkConst ``Nat then return ← mkAppM ``Int.natAbs #[e]
| expectedType t => do
if ← isDefEq eType t then return e
else if eType.isProp && (← isDefEq t (mkConst ``Bool)) then whnf $ ← mkAppOptM ``decide #[some e, none]
else if (← isDefEq eType (mkConst ``Bool)) && t.isProp then whnf $ ← mkAppM ``eq_true #[e]
else if (← isDefEq eType (mkConst ``Nat)) && (← isDefEq t (mkConst ``Int)) then return ← mkAppM ``Int.ofNat #[e]
else if (← isDefEq eType (mkConst ``Int)) && (← isDefEq t (mkConst ``Nat)) then return ← mkAppM ``Int.natAbs #[e]
else throwError "correctType :: {e} is parsed as {eType} which is not a {t}"

mutual
Expand All @@ -359,11 +376,12 @@ partial def parseSortedVar (sortedVar : Term) (symbolMap : Std.HashMap String Ex
match parseTermConstraint with
| noConstraint => return (varSymbol, varTypeExp)
| expectedType t =>
if varTypeExp == t then return (varSymbol, varTypeExp)
else if varTypeExp.isProp && t == mkConst ``Bool then return (varSymbol, t)
else if varTypeExp == mkConst ``Bool && t.isProp then return (varSymbol, t)
else if varTypeExp == mkConst ``Nat && t == mkConst ``Int then return (varSymbol, t)
else if varTypeExp == mkConst ``Int && t == mkConst ``Nat then return (varSymbol, t)
let mut tAndVarTypeExpCompatible ← isDefEq varTypeExp t
tAndVarTypeExpCompatible := tAndVarTypeExpCompatible || (varTypeExp.isProp && (← isDefEq t (mkConst ``Bool)))
tAndVarTypeExpCompatible := tAndVarTypeExpCompatible || ((← isDefEq varTypeExp (mkConst ``Bool)) && t.isProp)
tAndVarTypeExpCompatible := tAndVarTypeExpCompatible || ((← isDefEq varTypeExp (mkConst ``Nat)) && (← isDefEq t (mkConst ``Int)))
tAndVarTypeExpCompatible := tAndVarTypeExpCompatible || ((← isDefEq varTypeExp (mkConst ``Int)) && (← isDefEq t (mkConst ``Nat)))
if tAndVarTypeExpCompatible then return (varSymbol, t)
else throwError "parseSortedVar :: {sortedVar} is parsed as having type {varTypeExp} which is not the expected type {t}"
| _ => throwError "parseSortedVar :: Failed to parse {sortedVar} as a sortedVar"
| _ => throwError "parseSortedVar :: {sortedVar} is supposed to be a sortedVar, not an atom"
Expand Down Expand Up @@ -477,7 +495,7 @@ partial def parseLambda (vs : List Term) (symbolMap : Std.HashMap String Expr) (
return ← parseLambdaBodyWithSortedVars vs sortedVars symbolMap lambdaBody noConstraint
catch _ =>
continue
throwError "parseLambda :: Failed to parse exists expression with vs: {vs}"
throwError "parseLambda :: Failed to parse lambda expression with vs: {vs}"
| expectedType t =>
let lambdaArgTypes := (getExplicitForallArgumentTypes t).toArray
if lambdaArgTypes.size != sortedVars.size then
Expand Down Expand Up @@ -602,13 +620,13 @@ partial def parseTerm (e : Term) (symbolMap : Std.HashMap String Expr) (parseTer
let arg ← parseTerm arg symbolMap noConstraint
let argType ← inferType arg
if argType.isProp then mkAppM s1 #[arg]
else if argType == mkConst ``Bool then mkAppM s2 #[arg]
else if (← isDefEq argType (mkConst ``Bool)) then mkAppM s2 #[arg]
else throwError "parseTerm :: {arg} was not be interpreted as Prop or Bool in {e}"
| expectedType t =>
if t.isProp then
let arg ← parseTerm arg symbolMap (expectedType t)
mkAppM s1 #[arg]
else if t == mkConst ``Bool then
else if (← isDefEq t (mkConst ``Bool)) then
let arg ← parseTerm arg symbolMap (expectedType t)
mkAppM s2 #[arg]
else
Expand Down Expand Up @@ -651,14 +669,33 @@ partial def parseTerm (e : Term) (symbolMap : Std.HashMap String Expr) (parseTer
| noConstraint => parseLeftAssocApp s1 restVs symbolMap (expectedType (.sort 0)) -- Favor `Prop` interpretation over `Bool` interpretation
| expectedType t =>
if t.isProp then parseLeftAssocApp s1 restVs symbolMap (expectedType t)
else if t == mkConst ``Bool then parseLeftAssocApp s2 restVs symbolMap (expectedType t)
else if (← isDefEq t (mkConst ``Bool)) then parseLeftAssocApp s2 restVs symbolMap (expectedType t)
else throwError "parseTerm :: {e} has a head symbol {s} that does not permit it to have type {t}"
| [(s, Minus)] =>
match restVs with
| [arg] => -- Subtraction is left associative, but if it takes in just one argument, Minus is interpreted as negation
let arg ← parseTerm arg symbolMap parseTermConstraint
mkAppM ``Neg.neg #[arg]
| _ => parseLeftAssocApp s restVs symbolMap parseTermConstraint
| _ =>
if s == ``Nat.sub then
match parseTermConstraint with
| noConstraint =>
parseLeftAssocApp s restVs symbolMap (expectedType (mkConst ``Nat))
| expectedType t =>
if (← isDefEq t (mkConst ``Nat)) then parseLeftAssocApp s restVs symbolMap (expectedType t)
else throwError "parseTerm :: {e} has a head symbol {s} that does not permit it to have type {t}"
else
parseLeftAssocApp s restVs symbolMap parseTermConstraint
| [(_, Ite)] =>
match restVs with
| [cond, thenBranch, elseBranch] =>
-- **TODO** As with `Eq`, we should try to make `thenBranch` and `elseBranch` match each other's type
-- (if parseTermConstraint is `noConstraint`)
let cond ← parseTerm cond symbolMap (expectedType (.sort 0))
let thenBranch ← parseTerm thenBranch symbolMap parseTermConstraint
let elseBranch ← parseTerm elseBranch symbolMap parseTermConstraint
mkAppM ``ite #[cond, thenBranch, elseBranch]
| _ => throwError "parseTerm :: {e} has ite as a head symbol but does not take in exactly three arguments"
| [] =>
match symbolMap.get? s with
| some symbolExp =>
Expand Down

0 comments on commit 008a6c5

Please sign in to comment.