Skip to content

Commit

Permalink
Rigid unification option for hint solve/exact
Browse files Browse the repository at this point in the history
This features introduced a new `rigid` flag for `hint
exact/solve`. Lemmas added to a hint with this flag are applied using
a rigid unification algorithm.

Co-Authored-By: Gustavo Delerue <[email protected]>
Co-Authored-By: Pierre-Yves Strub <[email protected]>
  • Loading branch information
Gustavo2622 and strub committed Jan 17, 2025
1 parent 3d69739 commit 10af190
Show file tree
Hide file tree
Showing 15 changed files with 203 additions and 102 deletions.
15 changes: 10 additions & 5 deletions src/ecCommands.ml
Original file line number Diff line number Diff line change
Expand Up @@ -264,17 +264,22 @@ module HiPrinting = struct
let ppe0 = EcPrinting.PPEnv.ofenv env in
EcPrinting.pp_by_theory ppe0 (EcPrinting.pp_axiom) fmt ax

(* ------------------------------------------------------------------ *)
let pr_hint_solve (fmt : Format.formatter) (env : EcEnv.env) =
let hint_solve = EcEnv.Auto.all env in
let hint_solve = List.map (fun p ->
(p, EcEnv.Ax.by_path p env)
let hint_solve = List.map (fun (p, mode) ->
let ax = EcEnv.Ax.by_path p env in
(p, (ax, mode))
) hint_solve in

let ppe = EcPrinting.PPEnv.ofenv env in

let pp_hint_solve ppe fmt pax =
Format.fprintf fmt "%a" (EcPrinting.pp_axiom ppe) pax
let pp_hint_solve ppe fmt = (fun (p, (ax, mode)) ->
let mode =
match mode with
| `Default -> ""
| `Rigid -> "(rigid)" in
Format.fprintf fmt "%a %s" (EcPrinting.pp_axiom ppe) (p, ax) mode
)
in

EcPrinting.pp_by_theory ppe pp_hint_solve fmt hint_solve
Expand Down
59 changes: 40 additions & 19 deletions src/ecEnv.ml
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ type preenv = {
env_tci : ((ty_params * ty) * tcinstance) list;
env_tc : TC.graph;
env_rwbase : Sp.t Mip.t;
env_atbase : (path list Mint.t) Msym.t;
env_atbase : atbase Msym.t;
env_redbase : mredinfo;
env_ntbase : ntbase Mop.t;
env_modlcs : Sid.t; (* declared modules *)
Expand Down Expand Up @@ -221,6 +221,10 @@ and env_notation = ty_params * EcDecl.notation

and ntbase = (path * env_notation) list

and atbase0 = path * [`Rigid | `Default]

and atbase = atbase0 list Mint.t

(* -------------------------------------------------------------------- *)
type env = preenv

Expand Down Expand Up @@ -1516,39 +1520,53 @@ end
(* -------------------------------------------------------------------- *)
module Auto = struct
type base0 = path * [`Rigid | `Default]
let dname : symbol = ""
let updatedb ~level ?base (ps : path list) (db : (path list Mint.t) Msym.t) =
let updatedb
~(level : int)
?(base : symbol option)
(ps : atbase0 list)
(db : atbase Msym.t)
=
let nbase = (odfl dname base) in
let ps' = Msym.find_def Mint.empty nbase db in
let ps' =
let base = Msym.find_def Mint.empty nbase db in
let levels =
let doit x = Some (ofold (fun x ps -> ps @ x) ps x) in
Mint.change doit level ps' in
Msym.add nbase ps' db
let add ?(import = import0) ~level ?base (ps : path list) lc (env : env) =
Mint.change doit level base in
Msym.add nbase levels db
let add
?(import = import0)
~(level : int)
?(base : symbol option)
(axioms : atbase0 list)
(locality : is_local)
(env : env)
=
let env =
if import.im_immediate then
{ env with
env_atbase = updatedb ?base ~level ps env.env_atbase; }
env_atbase = updatedb ?base ~level axioms env.env_atbase; }
else env
in
{ env with env_item = mkitem import
(Th_auto (level, base, ps, lc)) :: env.env_item; }
(Th_auto { level; base; axioms; locality; }) :: env.env_item; }
let add1 ?import ~level ?base (p : path) lc (env : env) =
let add1 ?import ~level ?base (p : atbase0) lc (env : env) =
add ?import ?base ~level [p] lc env
let get_core ?base (env : env) =
Msym.find_def Mint.empty (odfl dname base) env.env_atbase
let flatten_db (db : path list Mint.t) =
let flatten_db (db : atbase) =
Mint.fold_left (fun ps _ ps' -> ps @ ps') [] db
let get ?base (env : env) =
flatten_db (get_core ?base env)
let getall (bases : symbol list) (env : env) =
let getall (bases : symbol list) (env : env) : atbase0 list =
let dbs = List.map (fun base -> get_core ~base env) bases in
let dbs =
List.fold_left (fun db mi ->
Expand All @@ -1560,7 +1578,7 @@ module Auto = struct
let db = Msym.find_def Mint.empty base env.env_atbase in
Mint.bindings db
let all (env : env) : path list =
let all (env : env) : atbase0 list =
Msym.values env.env_atbase |> List.map flatten_db |> List.flatten
end
Expand Down Expand Up @@ -2951,8 +2969,8 @@ module Theory = struct
(* ------------------------------------------------------------------ *)
let bind_at_th =
let for1 _path db = function
| Th_auto (level, base, ps, _) ->
Some (Auto.updatedb ?base ~level ps db)
| Th_auto {level; base; axioms; _} ->
Some (Auto.updatedb ?base ~level axioms db)
| _ -> None
in bind_base_th for1
Expand Down Expand Up @@ -3125,9 +3143,12 @@ module Theory = struct
let ps = List.filter ((not) |- inclear |- oget |- EcPath.prefix) ps in
if List.is_empty ps then None else Some (Th_addrw (p, ps,lc))
| Th_auto (lvl, base, ps, lc) ->
let ps = List.filter ((not) |- inclear |- oget |- EcPath.prefix) ps in
if List.is_empty ps then None else Some (Th_auto (lvl, base, ps, lc))
| Th_auto ({ axioms } as auto_rl) ->
let axioms = List.filter (fun (p, _) ->
let p = oget (EcPath.prefix p) in
not (inclear p)
) axioms in
if List.is_empty axioms then None else Some (Th_auto {auto_rl with axioms})
| (Th_export (p, _)) as item ->
if Sp.mem p cleared then None else Some item
Expand Down
14 changes: 8 additions & 6 deletions src/ecEnv.mli
Original file line number Diff line number Diff line change
Expand Up @@ -415,13 +415,15 @@ end

(* -------------------------------------------------------------------- *)
module Auto : sig
type base0 = path * [`Rigid | `Default]

val dname : symbol
val add1 : ?import:import -> level:int -> ?base:symbol -> path -> is_local -> env -> env
val add : ?import:import -> level:int -> ?base:symbol -> path list -> is_local -> env -> env
val get : ?base:symbol -> env -> path list
val getall : symbol list -> env -> path list
val getx : symbol -> env -> (int * path list) list
val all : env -> path list
val add1 : ?import:import -> level:int -> ?base:symbol -> base0 -> is_local -> env -> env
val add : ?import:import -> level:int -> ?base:symbol -> base0 list -> is_local -> env -> env
val get : ?base:symbol -> env -> base0 list
val getall : symbol list -> env -> base0 list
val getx : symbol -> env -> (int * base0 list) list
val all : env -> base0 list
end

(* -------------------------------------------------------------------- *)
Expand Down
33 changes: 19 additions & 14 deletions src/ecLowGoal.ml
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,7 @@ module Apply = struct

exception NoInstance of (bool * reason * PT.pt_env * (form * form))

let t_apply_bwd_r ?(mode = fmdelta) ?(canview = true) pt (tc : tcenv1) =
let t_apply_bwd_r ?(ri = EcReduction.full_compat) ?(mode = fmdelta) ?(canview = true) pt (tc : tcenv1) =
let ((hyps, concl), pterr) = (FApi.tc1_flat tc, PT.copy pt.ptev_env) in

let noinstance ?(dpe = false) reason =
Expand All @@ -736,7 +736,7 @@ module Apply = struct
match istop && PT.can_concretize pt.PT.ptev_env with
| true ->
let ax = PT.concretize_form pt.PT.ptev_env pt.PT.ptev_ax in
if EcReduction.is_conv ~ri:EcReduction.full_compat hyps ax concl
if EcReduction.is_conv ~ri hyps ax concl
then pt
else instantiate canview false pt

Expand All @@ -747,7 +747,7 @@ module Apply = struct
noinstance `IncompleteInference;
pt
with EcMatching.MatchFailure ->
match TTC.destruct_product hyps pt.PT.ptev_ax with
match TTC.destruct_product ~reduce:(mode.fm_conv) hyps pt.PT.ptev_ax with
| Some _ ->
(* FIXME: add internal marker *)
instantiate canview false (PT.apply_pterm_to_hole pt)
Expand Down Expand Up @@ -800,15 +800,15 @@ module Apply = struct

t_apply pt tc

let t_apply_bwd ?mode ?canview pt (tc : tcenv1) =
let t_apply_bwd ?(ri : EcReduction.reduction_info option) ?mode ?canview pt (tc : tcenv1) =
let hyps = FApi.tc1_hyps tc in
let pt, ax = LowApply.check `Elim pt (`Hyps (hyps, !!tc)) in
let ptenv = ptenv_of_penv hyps !!tc in
let pt = { ptev_env = ptenv; ptev_pt = pt; ptev_ax = ax; } in
t_apply_bwd_r ?mode ?canview pt tc
t_apply_bwd_r ?ri ?mode ?canview pt tc

let t_apply_bwd_hi ?(dpe = true) ?mode ?canview pt (tc : tcenv1) =
try t_apply_bwd ?mode ?canview pt tc
let t_apply_bwd_hi ?(ri : EcReduction.reduction_info option) ?(dpe = true) ?mode ?canview pt (tc : tcenv1) =
try t_apply_bwd ?ri ?mode ?canview pt tc
with (NoInstance (_, r, pt, f)) ->
tc_error_exn !!tc (NoInstance (dpe, r, pt, f))
end
Expand Down Expand Up @@ -2582,22 +2582,27 @@ let t_coq
let t_solve ?(canfail = true) ?(bases = [EcEnv.Auto.dname]) ?(mode = fmdelta) ?(depth = 1) (tc : tcenv1) =
let bases = EcEnv.Auto.getall bases (FApi.tc1_env tc) in

let t_apply1 p tc =

let t_apply1 ((p, rigid): Auto.base0) tc =
let ri, mode =
match rigid with
| `Rigid -> EcReduction.no_red, fmsearch
| `Default -> EcReduction.full_compat, mode in
let pt = PT.pt_of_uglobal !!tc (FApi.tc1_hyps tc) p in
try
Apply.t_apply_bwd_r ~mode ~canview:false pt tc
with Apply.NoInstance _ -> t_fail tc in
Apply.t_apply_bwd_r ~ri ~mode ~canview:false pt tc
with Apply.NoInstance _ ->
t_fail tc
in

let rec t_apply ctn p tc =
let rec t_apply ctn ip tc =
if ctn > depth
then t_fail tc
else (t_apply1 p @! t_trivial @! t_solve (ctn + 1) bases) tc
else (t_apply1 ip @! t_trivial @! t_solve (ctn + 1) bases) tc

and t_solve ctn bases tc =
match bases with
| [] -> t_abort tc
| p::bases -> (FApi.t_or (t_apply ctn p) (t_solve ctn bases)) tc in
| ip::bases -> (FApi.t_or (t_apply ctn ip) (t_solve ctn bases)) tc in

let t = t_solve 0 bases in
let t = if canfail then FApi.t_try t else t in
Expand Down
7 changes: 3 additions & 4 deletions src/ecLowGoal.mli
Original file line number Diff line number Diff line change
Expand Up @@ -134,14 +134,13 @@ module Apply : sig
exception NoInstance of (bool * reason * pt_env * (form * form))

val t_apply_bwd_r :
?mode:fmoptions -> ?canview:bool -> pt_ev -> FApi.backward
?ri:EcReduction.reduction_info -> ?mode:fmoptions -> ?canview:bool -> pt_ev -> FApi.backward

val t_apply_bwd :
?mode:fmoptions -> ?canview:bool -> proofterm -> FApi.backward
?ri:EcReduction.reduction_info -> ?mode:fmoptions -> ?canview:bool -> proofterm -> FApi.backward

val t_apply_bwd_hi:
?dpe:bool -> ?mode:fmoptions -> ?canview:bool
-> proofterm -> FApi.backward
?ri:EcReduction.reduction_info -> ?dpe:bool -> ?mode:fmoptions -> ?canview:bool -> proofterm -> FApi.backward
end

(* -------------------------------------------------------------------- *)
Expand Down
25 changes: 18 additions & 7 deletions src/ecParser.mly
Original file line number Diff line number Diff line change
Expand Up @@ -3699,14 +3699,25 @@ addrw:
| local=is_local HINT REWRITE p=lqident COLON l=lqident*
{ (local, p, l) }

hint:
| local=is_local HINT EXACT base=lident? COLON l=qident*
{ { ht_local = local; ht_prio = 0;
ht_base = base ; ht_names = l; } }
hintoption:
| x=lident {
match unloc x with
| "rigid" -> `Rigid
| _ ->
parse_error x.pl_loc
(Some ("invalid option: " ^ (unloc x)))
}

| local=is_local HINT SOLVE i=word base=lident? COLON l=qident*
{ { ht_local = local; ht_prio = i;
ht_base = base ; ht_names = l; } }
hint:
| local=is_local
HINT opts=ioption(bracket(hintoption)+)
prio=ID(EXACT { 0 } | SOLVE i=word { i })
base=lident? COLON l=qident*
{ { ht_local = local;
ht_prio = prio;
ht_base = base ;
ht_names = l;
ht_options = odfl [] opts; } }

(* -------------------------------------------------------------------- *)
(* User reduction *)
Expand Down
12 changes: 8 additions & 4 deletions src/ecParsetree.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1228,12 +1228,16 @@ type save = [ `Qed | `Admit | `Abort ]
(* -------------------------------------------------------------------- *)
type theory_clear = (pqsymbol option) list

(* -------------------------------------------------------------------- *)
type phintoption = [ `Rigid ]

(* -------------------------------------------------------------------- *)
type phint = {
ht_local : is_local;
ht_prio : int;
ht_base : psymbol option;
ht_names : pqsymbol list;
ht_local : is_local;
ht_prio : int;
ht_base : psymbol option;
ht_names : pqsymbol list;
ht_options : phintoption list;
}

(* -------------------------------------------------------------------- *)
Expand Down
35 changes: 25 additions & 10 deletions src/ecPrinting.ml
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,13 @@ let pp_rwname ppe fmt p =
let pp_axname ppe fmt p =
Format.fprintf fmt "%a" EcSymbols.pp_qsymbol (PPEnv.ax_symb ppe p)

let pp_axhnt ppe fmt (p, b) =
let b =
match b with
| `Default -> ""
| `Rigid -> " (rigid)" in
Format.fprintf fmt "%a%s" (pp_axname ppe) p b

(* -------------------------------------------------------------------- *)
let pp_thname ppe fmt p =
EcSymbols.pp_qsymbol fmt (PPEnv.th_symb ppe p)
Expand Down Expand Up @@ -3020,23 +3027,31 @@ let pp_rwbase ppe fmt (p, rws) =
(pp_rwname ppe) p (pp_list ", " (pp_axname ppe)) (Sp.elements rws)

(* -------------------------------------------------------------------- *)
let pp_solvedb ppe fmt db =
let pp_solvedb ppe fmt (db: (int * (P.path * _) list) list) =
List.iter (fun (lvl, ps) ->
Format.fprintf fmt "[%3d] %a\n%!"
lvl (pp_list ", " (pp_axname ppe)) ps)
lvl
(pp_list ", " (pp_axhnt ppe))
ps)
db;

let lemmas = List.flatten (List.map snd db) in
let lemmas = List.pmap (fun p ->
let lemmas = List.pmap (fun (p, ir) ->
let ax = EcEnv.Ax.by_path_opt p ppe.PPEnv.ppe_env in
(omap (fun ax -> (p, ax)) ax))
lemmas
(omap (fun ax -> (ir, (p, ax))) ax)
) lemmas
in

if not (List.is_empty lemmas) then begin
Format.fprintf fmt "\n%!";
List.iter
(fun ax -> Format.fprintf fmt "%a\n\n%!" (pp_axiom ppe) ax)
(fun (ir, ax) ->
let ir =
match ir with
| `Default -> ""
| `Rigid -> " (rigid)" in

Format.fprintf fmt "%a%s\n\n%!" (pp_axiom ppe) ax ir)
lemmas
end

Expand Down Expand Up @@ -3526,11 +3541,11 @@ let rec pp_theory ppe (fmt : Format.formatter) (path, cth) =
(* FIXME: section we should add the lemma in the reduction *)
Format.fprintf fmt "hint simplify."

| EcTheory.Th_auto (lvl, base, p, lc) ->
| EcTheory.Th_auto { level; base; axioms; locality; } ->
Format.fprintf fmt "%ahint solve %d %s : %a."
pp_locality lc
lvl (odfl "" base)
(pp_list "@ " (pp_axname ppe)) p
pp_locality locality
level (odfl "" base)
(pp_list "@ " (pp_axhnt ppe)) axioms

(* -------------------------------------------------------------------- *)
let pp_stmt_with_nums (ppe : PPEnv.t) fmt stmt =
Expand Down
Loading

0 comments on commit 10af190

Please sign in to comment.