Skip to content

Commit

Permalink
Selector support for mutually recursive inductive datatypes
Browse files Browse the repository at this point in the history
  • Loading branch information
JOSHCLUNE committed Aug 13, 2024
1 parent dad4d6c commit 6a8c8ff
Showing 1 changed file with 86 additions and 60 deletions.
146 changes: 86 additions & 60 deletions Auto/Tactic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -774,38 +774,51 @@ def buildSelectorForInhabitedType (selCtor : Expr) (argIdx : Nat) : MetaM Expr :
let ival ← matchConstInduct datatype.getAppFn
(fun _ => throwError "buildSelectorForInhabitedType :: The datatype of {selCtor} ({datatype}) is not a datatype")
(fun ival _ => pure ival)
let mutuallyRecursiveDatatypes ← ival.all.mapM
(fun n => do
let nConst ← Meta.mkAppM' (mkConst n lvls) selCtorParams
matchConstInduct nConst.getAppFn
(fun _ => throwError "buildSelectorForInhabitedType :: Error in gathering InductiveVal for {nConst} which should be mutually recursive with {datatype}")
(fun ival _ => pure (nConst, ival)))
let recursor := (mkConst (.str datatypeName "rec") (selectorOutputUniverseLevel :: lvls))
let mut recursorArgs := selCtorParams
let motive := .lam `_ datatype selectorOutputType .default
-- **TODO** Multiple motives for mutually recursive datatypes
recursorArgs := recursorArgs.push motive
for curCtorIdx in [:ival.ctors.length] do
if curCtorIdx == cval.cidx then
let decls := selCtorFieldTypes.mapIdx fun idx ty => (.str .anonymous ("arg" ++ idx.1.repr), fun prevArgs => pure (ty.instantiate prevArgs))
let nextRecursorArg ←
Meta.withLocalDeclsD decls fun curCtorFields => do
let recursiveFieldMotiveDecls ← curCtorFields.filterMapM
(fun ctorFieldFVar => do
if (← Meta.inferType ctorFieldFVar) == datatype then return some $ (`_, (fun _ => Meta.mkAppM' motive #[ctorFieldFVar]))
else return none)
Meta.withLocalDeclsD recursiveFieldMotiveDecls fun recursiveFieldMotiveFVars => do
Meta.mkLambdaFVars (curCtorFields ++ recursiveFieldMotiveFVars) curCtorFields[argIdx]!
recursorArgs := recursorArgs.push nextRecursorArg
else
let curCtor := mkConst ival.ctors[curCtorIdx]! lvls
let curCtor ← Meta.mkAppOptM' curCtor (selCtorParams.map some)
let curCtorType ← Meta.inferType curCtor
let curCtorFieldTypes := (getForallArgumentTypes curCtorType).toArray
let decls := curCtorFieldTypes.mapIdx fun idx ty => (.str .anonymous ("arg" ++ idx.1.repr), fun prevArgs => pure (ty.instantiate prevArgs))
let nextRecursorArg ←
Meta.withLocalDeclsD decls fun curCtorFields => do
let recursiveFieldMotiveDecls ← curCtorFields.filterMapM
(fun ctorFieldFVar => do
if (← Meta.inferType ctorFieldFVar) == datatype then return some $ (`_, (fun _ => Meta.mkAppM' motive #[ctorFieldFVar]))
else return none)
Meta.withLocalDeclsD recursiveFieldMotiveDecls fun recursiveFieldMotiveFVars => do
Meta.mkLambdaFVars (curCtorFields ++ recursiveFieldMotiveFVars) $ ← Meta.mkAppOptM `Inhabited.default #[some selectorOutputType, none]
recursorArgs := recursorArgs.push nextRecursorArg
let motives := mutuallyRecursiveDatatypes.map (fun (t, _) => Expr.lam `_ t selectorOutputType .default)
recursorArgs := recursorArgs ++ motives.toArray
let datatypesAndMotives := mutuallyRecursiveDatatypes.zip motives
for (curDatatype, curDatatypeInfo) in mutuallyRecursiveDatatypes do
for curCtorIdx in [:curDatatypeInfo.ctors.length] do
if curDatatype == datatype && curCtorIdx == cval.cidx then
let decls := selCtorFieldTypes.mapIdx fun idx ty => (.str .anonymous ("arg" ++ idx.1.repr), fun prevArgs => pure (ty.instantiate prevArgs))
let nextRecursorArg ←
Meta.withLocalDeclsD decls fun curCtorFields => do
let recursiveFieldMotiveDecls ← curCtorFields.filterMapM
(fun ctorFieldFVar => do
let ctorFieldFVarType ← Meta.inferType ctorFieldFVar
match datatypesAndMotives.find? (fun ((t, _), _) => t == ctorFieldFVarType) with
| none => return none
| some (_, ctorMotive) => return some $ (`_, ((fun _ => Meta.mkAppM' ctorMotive #[ctorFieldFVar]) : Array Expr → MetaM Expr))
)
Meta.withLocalDeclsD recursiveFieldMotiveDecls fun recursiveFieldMotiveFVars => do
Meta.mkLambdaFVars (curCtorFields ++ recursiveFieldMotiveFVars) curCtorFields[argIdx]!
recursorArgs := recursorArgs.push nextRecursorArg
else
let curCtor := mkConst curDatatypeInfo.ctors[curCtorIdx]! lvls
let curCtor ← Meta.mkAppOptM' curCtor (selCtorParams.map some)
let curCtorType ← Meta.inferType curCtor
let curCtorFieldTypes := (getForallArgumentTypes curCtorType).toArray
let decls := curCtorFieldTypes.mapIdx fun idx ty => (.str .anonymous ("arg" ++ idx.1.repr), fun prevArgs => pure (ty.instantiate prevArgs))
let nextRecursorArg ←
Meta.withLocalDeclsD decls fun curCtorFields => do
let recursiveFieldMotiveDecls ← curCtorFields.filterMapM
(fun ctorFieldFVar => do
let ctorFieldFVarType ← Meta.inferType ctorFieldFVar
match datatypesAndMotives.find? (fun ((t, _), _) => t == ctorFieldFVarType) with
| none => return none
| some (_, ctorMotive) => return some $ (`_, ((fun _ => Meta.mkAppM' ctorMotive #[ctorFieldFVar]) : Array Expr → MetaM Expr))
)
Meta.withLocalDeclsD recursiveFieldMotiveDecls fun recursiveFieldMotiveFVars => do
Meta.mkLambdaFVars (curCtorFields ++ recursiveFieldMotiveFVars) $ ← Meta.mkAppOptM `Inhabited.default #[some selectorOutputType, none]
recursorArgs := recursorArgs.push nextRecursorArg
Meta.mkAppOptM' recursor $ recursorArgs.map some

/-- Given the constructor `selCtor` of some inductive datatype and an `argIdx` which is less than `selCtor`'s total number
Expand All @@ -832,38 +845,51 @@ def buildSelectorForUninhabitedType (selCtor : Expr) (argIdx : Nat) : MetaM Expr
let ival ← matchConstInduct datatype.getAppFn
(fun _ => throwError "buildSelectorForUninhabitedType :: The datatype of {selCtor} ({datatype}) is not a datatype")
(fun ival _ => pure ival)
let mutuallyRecursiveDatatypes ← ival.all.mapM
(fun n => do
let nConst ← Meta.mkAppM' (mkConst n lvls) selCtorParams
matchConstInduct nConst.getAppFn
(fun _ => throwError "buildSelectorForUninhabitedType :: Error in gathering InductiveVal for {nConst} which should be mutually recursive with {datatype}")
(fun ival _ => pure (nConst, ival)))
let recursor := (mkConst (.str datatypeName "rec") (selectorOutputUniverseLevel :: lvls))
let mut recursorArgs := selCtorParams
let motive := .lam `_ datatype selectorOutputType .default
-- **TODO** Multiple motives for mutually recursive datatypes
recursorArgs := recursorArgs.push motive
for curCtorIdx in [:ival.ctors.length] do
if curCtorIdx == cval.cidx then
let decls := selCtorFieldTypes.mapIdx fun idx ty => (.str .anonymous ("arg" ++ idx.1.repr), fun prevArgs => pure (ty.instantiate prevArgs))
let nextRecursorArg ←
Meta.withLocalDeclsD decls fun curCtorFields => do
let recursiveFieldMotiveDecls ← curCtorFields.filterMapM
(fun ctorFieldFVar => do
if (← Meta.inferType ctorFieldFVar) == datatype then return some $ (`_, (fun _ => Meta.mkAppM' motive #[ctorFieldFVar]))
else return none)
Meta.withLocalDeclsD recursiveFieldMotiveDecls fun recursiveFieldMotiveFVars => do
Meta.mkLambdaFVars (curCtorFields ++ recursiveFieldMotiveFVars) curCtorFields[argIdx]!
recursorArgs := recursorArgs.push nextRecursorArg
else
let curCtor := mkConst ival.ctors[curCtorIdx]! lvls
let curCtor ← Meta.mkAppOptM' curCtor (selCtorParams.map some)
let curCtorType ← Meta.inferType curCtor
let curCtorFieldTypes := (getForallArgumentTypes curCtorType).toArray
let decls := curCtorFieldTypes.mapIdx fun idx ty => (.str .anonymous ("arg" ++ idx.1.repr), fun prevArgs => pure (ty.instantiate prevArgs))
let nextRecursorArg ←
Meta.withLocalDeclsD decls fun curCtorFields => do
let recursiveFieldMotiveDecls ← curCtorFields.filterMapM
(fun ctorFieldFVar => do
if (← Meta.inferType ctorFieldFVar) == datatype then return some $ (`_, (fun _ => Meta.mkAppM' motive #[ctorFieldFVar]))
else return none)
Meta.withLocalDeclsD recursiveFieldMotiveDecls fun recursiveFieldMotiveFVars => do
Meta.mkLambdaFVars (curCtorFields ++ recursiveFieldMotiveFVars) $ ← Meta.mkSorry selectorOutputType false
recursorArgs := recursorArgs.push nextRecursorArg
let motives := mutuallyRecursiveDatatypes.map (fun (t, _) => Expr.lam `_ t selectorOutputType .default)
recursorArgs := recursorArgs ++ motives.toArray
let datatypesAndMotives := mutuallyRecursiveDatatypes.zip motives
for (curDatatype, curDatatypeInfo) in mutuallyRecursiveDatatypes do
for curCtorIdx in [:curDatatypeInfo.ctors.length] do
if curDatatype == datatype && curCtorIdx == cval.cidx then
let decls := selCtorFieldTypes.mapIdx fun idx ty => (.str .anonymous ("arg" ++ idx.1.repr), fun prevArgs => pure (ty.instantiate prevArgs))
let nextRecursorArg ←
Meta.withLocalDeclsD decls fun curCtorFields => do
let recursiveFieldMotiveDecls ← curCtorFields.filterMapM
(fun ctorFieldFVar => do
let ctorFieldFVarType ← Meta.inferType ctorFieldFVar
match datatypesAndMotives.find? (fun ((t, _), _) => t == ctorFieldFVarType) with
| none => return none
| some (_, ctorMotive) => return some $ (`_, ((fun _ => Meta.mkAppM' ctorMotive #[ctorFieldFVar]) : Array Expr → MetaM Expr))
)
Meta.withLocalDeclsD recursiveFieldMotiveDecls fun recursiveFieldMotiveFVars => do
Meta.mkLambdaFVars (curCtorFields ++ recursiveFieldMotiveFVars) curCtorFields[argIdx]!
recursorArgs := recursorArgs.push nextRecursorArg
else
let curCtor := mkConst curDatatypeInfo.ctors[curCtorIdx]! lvls
let curCtor ← Meta.mkAppOptM' curCtor (selCtorParams.map some)
let curCtorType ← Meta.inferType curCtor
let curCtorFieldTypes := (getForallArgumentTypes curCtorType).toArray
let decls := curCtorFieldTypes.mapIdx fun idx ty => (.str .anonymous ("arg" ++ idx.1.repr), fun prevArgs => pure (ty.instantiate prevArgs))
let nextRecursorArg ←
Meta.withLocalDeclsD decls fun curCtorFields => do
let recursiveFieldMotiveDecls ← curCtorFields.filterMapM
(fun ctorFieldFVar => do
let ctorFieldFVarType ← Meta.inferType ctorFieldFVar
match datatypesAndMotives.find? (fun ((t, _), _) => t == ctorFieldFVarType) with
| none => return none
| some (_, ctorMotive) => return some $ (`_, ((fun _ => Meta.mkAppM' ctorMotive #[ctorFieldFVar]) : Array Expr → MetaM Expr))
)
Meta.withLocalDeclsD recursiveFieldMotiveDecls fun recursiveFieldMotiveFVars => do
Meta.mkLambdaFVars (curCtorFields ++ recursiveFieldMotiveFVars) $ ← Meta.mkSorry selectorOutputType false
recursorArgs := recursorArgs.push nextRecursorArg
Meta.mkAppOptM' recursor $ recursorArgs.map some

/-- Given the constructor `selCtor` of some inductive datatype and an `argIdx` which is less than `selCtor`'s total number
Expand Down

0 comments on commit 6a8c8ff

Please sign in to comment.