diff --git a/Auto/Parser/SMTParser.lean b/Auto/Parser/SMTParser.lean index 67c134e..fa25c8a 100644 --- a/Auto/Parser/SMTParser.lean +++ b/Auto/Parser/SMTParser.lean @@ -236,6 +236,7 @@ inductive SymbolInput | TwoExactNoConstraint -- Used for symbols like `<` that take in exactly two nonProp/nonBool arguments | TwoExactEq -- Specifically used for `=` which can invoke Prop typing constraints if a Prop and Bool are equated | Minus -- Minus is left-associative when given ≥ 2 arguments but is also used for unary negation +| Ite -- Used for `ite` which takes in exactly three arguments open SymbolInput @@ -250,12 +251,14 @@ def smtSymbolToLeanName (s : String) : List (Name × SymbolInput) := | ">=" => [(``GE.ge, TwoExactNoConstraint)] | "+" => [(``HAdd.hAdd, LeftAssocNoConstraint)] | "-" => [(``HSub.hSub, Minus)] -- Minus is left-associative when given ≥ 2 arguments but is also used for unary negation + | "nsub" => [(``Nat.sub, Minus)] | "*" => [(``HMul.hMul, LeftAssocNoConstraint)] | "/" => [(``HDiv.hDiv, LeftAssocNoConstraint)] | "or" => [(``Or, LeftAssocAllProp), (``or, LeftAssocAllBool)] | "and" => [(``And, LeftAssocAllProp), (``and, LeftAssocAllBool)] | "not" => [(``Not, UnaryProp), (``not, UnaryBool)] | "=" => [(``Eq, TwoExactEq)] + | "ite" => [(``Ite, Ite)] | _ => [] def builtInSymbolMap : Std.HashMap String Expr := @@ -265,6 +268,20 @@ def builtInSymbolMap : Std.HashMap String Expr := let map := map.insert "Bool" (.sort .zero) let map := map.insert "false" (mkConst ``False) let map := map.insert "true" (mkConst ``True) + let map := map.insert "<" (mkConst ``LT.lt) + let map := map.insert "<=" (mkConst ``LE.le) + let map := map.insert ">" (mkConst ``GT.gt) + let map := map.insert ">=" (mkConst ``GE.ge) + let map := map.insert "+" (mkConst ``HAdd.hAdd) + let map := map.insert "-" (mkConst ``HSub.hSub) + let map := map.insert "nsub" (mkConst ``Nat.sub) + let map := map.insert "*" (mkConst ``HMul.hMul) + let map := map.insert "/" (mkConst ``HDiv.hDiv) + let map := map.insert "or" (mkConst ``Or) + let map := map.insert "and" (mkConst ``And) + let map := map.insert "not" (mkConst ``Not) + let map := map.insert "=" (mkConst ``Eq) + let map := map.insert "ite" (mkConst ``Ite) map /-- Given an expression `∀ x1 : t1, x2 : t2, ... xn : tn, b`, returns `[t1, t2, ..., tn]`. If the given expression is not @@ -336,14 +353,14 @@ def correctType (e : Expr) (parseTermConstraint : ParseTermConstraint) : MetaM E let eType ← inferType e match parseTermConstraint with | noConstraint => - if eType == mkConst ``Nat then mkAppM ``Int.ofNat #[e] + if ← isDefEq eType (mkConst ``Nat) then mkAppM ``Int.ofNat #[e] else return e - | expectedType t => - if eType == t then return e - else if eType.isProp && t == mkConst ``Bool then whnf $ ← mkAppOptM ``decide #[some e, none] - else if eType == mkConst ``Bool && t.isProp then whnf $ ← mkAppM ``eq_true #[e] - else if eType == mkConst ``Nat && t == mkConst ``Int then return ← mkAppM ``Int.ofNat #[e] - else if eType == mkConst ``Int && t == mkConst ``Nat then return ← mkAppM ``Int.natAbs #[e] + | expectedType t => do + if ← isDefEq eType t then return e + else if eType.isProp && (← isDefEq t (mkConst ``Bool)) then whnf $ ← mkAppOptM ``decide #[some e, none] + else if (← isDefEq eType (mkConst ``Bool)) && t.isProp then whnf $ ← mkAppM ``eq_true #[e] + else if (← isDefEq eType (mkConst ``Nat)) && (← isDefEq t (mkConst ``Int)) then return ← mkAppM ``Int.ofNat #[e] + else if (← isDefEq eType (mkConst ``Int)) && (← isDefEq t (mkConst ``Nat)) then return ← mkAppM ``Int.natAbs #[e] else throwError "correctType :: {e} is parsed as {eType} which is not a {t}" mutual @@ -359,11 +376,12 @@ partial def parseSortedVar (sortedVar : Term) (symbolMap : Std.HashMap String Ex match parseTermConstraint with | noConstraint => return (varSymbol, varTypeExp) | expectedType t => - if varTypeExp == t then return (varSymbol, varTypeExp) - else if varTypeExp.isProp && t == mkConst ``Bool then return (varSymbol, t) - else if varTypeExp == mkConst ``Bool && t.isProp then return (varSymbol, t) - else if varTypeExp == mkConst ``Nat && t == mkConst ``Int then return (varSymbol, t) - else if varTypeExp == mkConst ``Int && t == mkConst ``Nat then return (varSymbol, t) + let mut tAndVarTypeExpCompatible ← isDefEq varTypeExp t + tAndVarTypeExpCompatible := tAndVarTypeExpCompatible || (varTypeExp.isProp && (← isDefEq t (mkConst ``Bool))) + tAndVarTypeExpCompatible := tAndVarTypeExpCompatible || ((← isDefEq varTypeExp (mkConst ``Bool)) && t.isProp) + tAndVarTypeExpCompatible := tAndVarTypeExpCompatible || ((← isDefEq varTypeExp (mkConst ``Nat)) && (← isDefEq t (mkConst ``Int))) + tAndVarTypeExpCompatible := tAndVarTypeExpCompatible || ((← isDefEq varTypeExp (mkConst ``Int)) && (← isDefEq t (mkConst ``Nat))) + if tAndVarTypeExpCompatible then return (varSymbol, t) else throwError "parseSortedVar :: {sortedVar} is parsed as having type {varTypeExp} which is not the expected type {t}" | _ => throwError "parseSortedVar :: Failed to parse {sortedVar} as a sortedVar" | _ => throwError "parseSortedVar :: {sortedVar} is supposed to be a sortedVar, not an atom" @@ -477,7 +495,7 @@ partial def parseLambda (vs : List Term) (symbolMap : Std.HashMap String Expr) ( return ← parseLambdaBodyWithSortedVars vs sortedVars symbolMap lambdaBody noConstraint catch _ => continue - throwError "parseLambda :: Failed to parse exists expression with vs: {vs}" + throwError "parseLambda :: Failed to parse lambda expression with vs: {vs}" | expectedType t => let lambdaArgTypes := (getExplicitForallArgumentTypes t).toArray if lambdaArgTypes.size != sortedVars.size then @@ -602,13 +620,13 @@ partial def parseTerm (e : Term) (symbolMap : Std.HashMap String Expr) (parseTer let arg ← parseTerm arg symbolMap noConstraint let argType ← inferType arg if argType.isProp then mkAppM s1 #[arg] - else if argType == mkConst ``Bool then mkAppM s2 #[arg] + else if (← isDefEq argType (mkConst ``Bool)) then mkAppM s2 #[arg] else throwError "parseTerm :: {arg} was not be interpreted as Prop or Bool in {e}" | expectedType t => if t.isProp then let arg ← parseTerm arg symbolMap (expectedType t) mkAppM s1 #[arg] - else if t == mkConst ``Bool then + else if (← isDefEq t (mkConst ``Bool)) then let arg ← parseTerm arg symbolMap (expectedType t) mkAppM s2 #[arg] else @@ -651,14 +669,33 @@ partial def parseTerm (e : Term) (symbolMap : Std.HashMap String Expr) (parseTer | noConstraint => parseLeftAssocApp s1 restVs symbolMap (expectedType (.sort 0)) -- Favor `Prop` interpretation over `Bool` interpretation | expectedType t => if t.isProp then parseLeftAssocApp s1 restVs symbolMap (expectedType t) - else if t == mkConst ``Bool then parseLeftAssocApp s2 restVs symbolMap (expectedType t) + else if (← isDefEq t (mkConst ``Bool)) then parseLeftAssocApp s2 restVs symbolMap (expectedType t) else throwError "parseTerm :: {e} has a head symbol {s} that does not permit it to have type {t}" | [(s, Minus)] => match restVs with | [arg] => -- Subtraction is left associative, but if it takes in just one argument, Minus is interpreted as negation let arg ← parseTerm arg symbolMap parseTermConstraint mkAppM ``Neg.neg #[arg] - | _ => parseLeftAssocApp s restVs symbolMap parseTermConstraint + | _ => + if s == ``Nat.sub then + match parseTermConstraint with + | noConstraint => + parseLeftAssocApp s restVs symbolMap (expectedType (mkConst ``Nat)) + | expectedType t => + if (← isDefEq t (mkConst ``Nat)) then parseLeftAssocApp s restVs symbolMap (expectedType t) + else throwError "parseTerm :: {e} has a head symbol {s} that does not permit it to have type {t}" + else + parseLeftAssocApp s restVs symbolMap parseTermConstraint + | [(_, Ite)] => + match restVs with + | [cond, thenBranch, elseBranch] => + -- **TODO** As with `Eq`, we should try to make `thenBranch` and `elseBranch` match each other's type + -- (if parseTermConstraint is `noConstraint`) + let cond ← parseTerm cond symbolMap (expectedType (.sort 0)) + let thenBranch ← parseTerm thenBranch symbolMap parseTermConstraint + let elseBranch ← parseTerm elseBranch symbolMap parseTermConstraint + mkAppM ``ite #[cond, thenBranch, elseBranch] + | _ => throwError "parseTerm :: {e} has ite as a head symbol but does not take in exactly three arguments" | [] => match symbolMap.get? s with | some symbolExp =>