Source file wp_rewrite.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
open CErrors
open Constr
open Equality
open Genarg
open Ltac_plugin.Tacarg
open Locus
open Names
open Pp
open Proofview
open Proofview.Notations
open Util
open Backtracking
open Proofutils
(** Raw (not yet interpreted) rewrite rule, with its source location:
    the lemma together with its universe context set, the rewriting
    direction ([true] = left to right, copied into [rew_l2r]), and an
    optional raw tactic that is interned when the rule is added. *)
type raw_rew_rule = (Constr.t Univ.in_universe_context_set * bool * raw_generic_argument option) CAst.t

(** Rewriting rules *)
type rew_rule = {
  rew_id : KerName.t;                          (* unique key; rules are compared and ranked by it *)
  rew_lemma : constr;                          (* the term used to rewrite *)
  rew_type: types;                             (* type of [rew_lemma] as found by [decompose_applied_relation]; used for printing *)
  rew_pat: constr;                             (* side of the relation to search for; key in the discrimination net *)
  rew_ctx: Univ.ContextSet.t;                  (* universe context, instantiated with fresh universes at each use *)
  rew_l2r: bool;                               (* [true] to rewrite left to right *)
  rew_tac: Genarg.glob_generic_argument option (* optional tactic handed to [general_rewrite] *)
}
module HintIdent = struct
  type t = rew_rule

  (** Rules are identified and ordered by their unique key [rew_id] only. *)
  let compare { rew_id = id1; _ } { rew_id = id2; _ } = KerName.compare id1 id2
end
(**
  Representation/approximation of terms to use in the dnet:
    - no meta or evar (use ['a pattern] for that)
    - [Rel]s and [Sort]s are not taken into account (that's why we need a second pass of linear filtering on the results - it's not a perfect term indexing structure)
*)
module DTerm =
struct
  (** Head labels of terms. Only the head constructor and the listed
      payloads are kept; subterms are indexed separately by the dnet. *)
  type 't t =
    | DRel
    | DSort
    | DRef    of GlobRef.t
    | DProd
    | DLet
    | DLambda
    | DApp
    | DCase   of case_info
    | DFix    of int array * int
    | DCoFix  of int
    | DInt    of Uint63.t
    | DFloat  of Float64.t
    | DArray

  (** Lexicographic order on the compared parts of a [case_info]:
      inductive type, number of parameters, then the per-constructor
      [ci_cstr_ndecls] and [ci_cstr_nargs] arrays. *)
  let compare_ci ci1 ci2 =
    let c = Ind.CanOrd.compare ci1.ci_ind ci2.ci_ind in
    if c = 0 then
      let c = Int.compare ci1.ci_npar ci2.ci_npar in
      if c = 0 then
        let c = Array.compare Int.compare ci1.ci_cstr_ndecls ci2.ci_cstr_ndecls in
        if c = 0 then
          Array.compare Int.compare ci1.ci_cstr_nargs ci2.ci_cstr_nargs
        else c
      else c
    else c

  (** Total order on labels. Constructors are ranked in declaration order,
      and two labels with the same constructor compare their payloads.
      The order must be reflexive ([compare x x = 0]). It is used as the
      label order of [Dn.Make], which presumably relies on it to find
      existing edges. *)
  let compare t1 t2 = match t1, t2 with
    | DRel, DRel -> 0
    | DRel, _ -> -1 | _, DRel -> 1
    | DSort, DSort -> 0
    | DSort, _ -> -1 | _, DSort -> 1
    | DRef gr1, DRef gr2 -> GlobRef.CanOrd.compare gr1 gr2
    | DRef _, _ -> -1 | _, DRef _ -> 1
    | DProd, DProd -> 0
    | DProd, _ -> -1 | _, DProd -> 1
    | DLet, DLet -> 0
    | DLet, _ -> -1 | _, DLet -> 1
    | DLambda, DLambda
    | DApp, DApp -> 0
    | DLambda, _ -> -1 | _, DLambda -> 1
    | DApp, _ -> -1 | _, DApp -> 1
    | DCase ci1, DCase ci2 -> compare_ci ci1 ci2
    | DCase _, _ -> -1 | _, DCase _ -> 1
    | DFix (i1, j1), DFix (i2, j2) ->
      let c = Int.compare j1 j2 in
      if c = 0 then
        Array.compare Int.compare i1 i2
      else c
    | DFix _, _ -> -1 | _, DFix _ -> 1
    | DCoFix i1, DCoFix i2 -> Int.compare i1 i2
    | DCoFix _, _ -> -1 | _, DCoFix _ -> 1
    | DInt i1, DInt i2 -> Uint63.compare i1 i2
    | DInt _, _ -> -1 | _, DInt _ -> 1
    | DFloat f1, DFloat f2 -> Float64.total_compare f1 f2
    | DFloat _, _ -> -1 | _, DFloat _ -> 1
    (* [DArray] carries no payload, so two of them are equal. *)
    | DArray, DArray -> 0
end
(**
  Term discrimination nets.

  Instantiates the generic dnet datatype of [Dn] with [DTerm] labels
  (as [unit DTerm.t]) and with rewrite rules ([HintIdent]) as the stored
  identifiers.
*)
module HintDN :
sig
  (** A discrimination net of rewrite rules. *)
  type t
  type ident = HintIdent.t
  (** The net containing no rule. *)
  val empty : t
  (** [add c i dn] adds the binding [(c,i)] to [dn]. [c] can be a
     closed term or a pattern (with untyped Evars). No Metas accepted *)
  val add : constr -> ident -> t -> t
  (** [find_all dn] returns all idents contained in dn *)
  val find_all : t -> ident list
end = struct
  module Ident = HintIdent
  module PTerm =
  struct
    type t = unit DTerm.t
    let compare = DTerm.compare
  end
  module TDnet = Dn.Make(PTerm)(Ident)
  open DTerm
  type t = TDnet.t
  type ident = HintIdent.t
  (* One step of the term approximation: returns the head label of [c]
     together with the immediate subterms to index below it, or [None]
     for an evar (presumably treated as a pattern hole by [Dn] - confirm).
     Universe instances and binder names are dropped, casts are looked
     through, and Metas must not occur (assertion failure). *)
  let pat_of_constr c : (unit DTerm.t * Constr.t list) option =
    let open GlobRef in
    let rec pat_of_constr c = match Constr.kind c with
      | Rel _ -> Some (DRel, [])
      | Sort _ -> Some (DSort, [])
      | Var i -> Some (DRef (VarRef i), [])
      | Const (c,u) -> Some (DRef (ConstRef c), [])
      | Ind (i,u) -> Some (DRef (IndRef i), [])
      | Construct (c,u) -> Some (DRef (ConstructRef c), [])
      | Meta _ -> assert false
      | Evar (i,_) -> None
      (* Children: return clause body, scrutinee, then branch bodies;
         the parameters [pms1] and the instance [u1] are ignored. *)
      | Case (ci,u1,pms1,c1,_iv,c2,ca) -> Some (DCase(ci), [snd c1; c2] @ Array.map_to_list snd ca)
      | Fix ((ia,i),(_,ta,ca)) -> Some (DFix(ia,i), Array.to_list ta @ Array.to_list ca)
      | CoFix (i,(_,ta,ca)) -> Some (DCoFix(i), Array.to_list ta @ Array.to_list ca)
      | Cast (c,_,_) -> pat_of_constr c
      | Lambda (_,t,c) -> Some (DLambda, [t; c])
      | Prod (_, t, u) -> Some (DProd, [t; u])
      | LetIn (_, c, t, u) -> Some (DLet, [c; t; u])
      (* Curried view: [f a1 ... an] becomes [DApp] with children
         [f a1 ... a(n-1)] and [an]. Assumes [App] nodes always carry at
         least one argument. *)
      | App (f,ca) ->
        let len = Array.length ca in
        let a = ca.(len - 1) in
        let ca = Array.sub ca 0 (len - 1) in
        Some (DApp, [mkApp (f, ca); a])
      (* A projection is indexed as the application of its constant. *)
      | Proj (p,c) -> pat_of_constr @@ mkApp (mkConst @@ Projection.constant p, [|c|])
      | Int i -> Some (DInt i, [])
      | Float f -> Some (DFloat f, [])
      | Array (_u,t,def,ty) -> Some (DArray, Array.to_list t @ [def ; ty])
    in pat_of_constr c
  
  let empty = TDnet.empty
  (* Leading products and let-ins are stripped; only the remaining
     conclusion is indexed. *)
  let add (c:constr) (id:Ident.t) (dn:t) =
    let (ctx, c) = Term.decompose_prod_assum c in
    let c = TDnet.pattern pat_of_constr c in
    TDnet.add dn c id
  (* The lookup callback answers [Everything] at every node, presumably
     so that no branch of the net is filtered out. *)
  let find_all dn = TDnet.lookup dn (fun () -> Everything) ()
end
(** Type of rewrite databases *)
type rewrite_db = {
  rdb_hintdn : HintDN.t;     (* discrimination net indexing the rules by their [rew_pat] *)
  rdb_order : int KNmap.t;   (* rank of each rule, keyed by [rew_id]: its insertion index *)
  rdb_maxuid : int;          (* rank that the next inserted rule will receive *)
}
(** Result of [decompose_applied_relation]. *)
type hypinfo = {
  hyp_ty : EConstr.types;    (* type of the lemma, possibly rebuilt after reduction *)
  hyp_pat : EConstr.constr;  (* relation argument to search for; may mention fresh clause evars *)
}
(** Empty rewrite database: it holds no rule, and the first rule inserted
    into it receives rank 0. *)
let empty_rewrite_db = {
  rdb_maxuid = 0;
  rdb_order = KNmap.empty;
  rdb_hintdn = HintDN.empty;
}
(** Generates a fresh kernel name for a rewrite rule.

    A module-level counter is incremented on every call. The module path
    is taken from a probe name built with [Lib.make_kn], and the label is
    ["<modpath>#<n>"], where [n] is the new counter value. *)
let fresh_key: unit -> KerName.t =
  let counter = ref 0 in
  fun () ->
    incr counter;
    let n = !counter in
    let probe = Lib.make_kn (Id.of_string ("_" ^ string_of_int n)) in
    let mp = fst (KerName.repr probe) in
    let name = Printf.sprintf "%s#%i" (ModPath.to_string mp) n in
    KerName.make mp (Label.of_id (Id.of_string_soft name))
(** [decompose_applied_relation env sigma c ctype left2right] looks for an
    applied relation (a head applied to at least two arguments) at the end
    of [ctype]. It returns the second-to-last argument when [left2right]
    holds and the last one otherwise. If [ctype] itself does not qualify,
    the search is retried once on [ctype] rebuilt after
    [Reductionops.splay_prod_assum]. [c] is not used. *)
let decompose_applied_relation (env: Environ.env) (sigma: Evd.evar_map) (c: constr) (ctype: Evd.econstr) (left2right: bool): hypinfo option =
  (* Premises of [ty] are instantiated with fresh evars; the extended evar
     map returned by [make_evar_clause] is only used locally. *)
  let find_rel ty =
    let clause_sigma, clause = EClause.make_evar_clause env sigma ty in
    let _, args = EConstr.decompose_app clause_sigma clause.EClause.cl_concl in
    let n = Array.length args in
    if n < 2 then None
    else if left2right then Some args.(n - 2)
    else Some args.(n - 1)
  in
  let with_type ty pat = { hyp_pat = pat; hyp_ty = ty } in
  match find_rel ctype with
  | Some pat -> Some (with_type ctype pat)
  | None ->
    let ctx, concl = Reductionops.splay_prod_assum env sigma ctype in
    let rebuilt = EConstr.it_mkProd_or_LetIn concl ctx in
    Option.map (with_type rebuilt) (find_rel rebuilt)
  
(** Inserts [rew_rules] into [rewrite_database] in list order. Each rule is
    indexed by its [rew_pat] and gets the next free rank. *)
let add_rew_rules (rewrite_database: rewrite_db) (rew_rules: rew_rule list): rewrite_db =
  let insert db rule = {
    rdb_hintdn = HintDN.add rule.rew_pat rule db.rdb_hintdn;
    rdb_order = KNmap.add rule.rew_id db.rdb_maxuid db.rdb_order;
    rdb_maxuid = succ db.rdb_maxuid;
  } in
  List.fold_left insert rewrite_database rew_rules
module RewriteDatabase: Mergeable with type elt = rewrite_db = struct
  type elt = rewrite_db

  (** The database without any rule. *)
  let empty = empty_rewrite_db

  (** [merge db1 db2] inserts every rule of [db2] into [db1]. The ranks
      recorded in [db2] are not reused: the incoming rules get fresh ranks
      in the order [HintDN.find_all] returns them. *)
  let merge rewrite_db1 rewrite_db2 =
    let incoming = HintDN.find_all rewrite_db2.rdb_hintdn in
    add_rew_rules rewrite_db1 incoming
end

(** Typed tactics whose results are rewrite databases, combined with
    [RewriteDatabase.merge]. *)
module RewriteDatabaseTactics = TypedTactics(RewriteDatabase)
(** All rules of [rewrite_database], most recently inserted (highest rank)
    first. *)
let find_rewrites (rewrite_database: rewrite_db): rew_rule list =
  let rank rule = KNmap.find rule.rew_id rewrite_database.rdb_order in
  HintDN.find_all rewrite_database.rdb_hintdn
  |> List.sort (fun a b -> Int.compare (rank b) (rank a))
(** Applies all the rules of one hint rewrite database.

    [where] is the hypothesis to rewrite in; [None] rewrites in the
    conclusion, as [autorewrite] does. [tactic] is run after every
    successful rewrite, wrapped in [tclIGNORE]. Rules are taken from
    [find_rewrites]. The combined tactic is repeated on the main goals as
    long as it makes progress. *)
let one_base (where: variable option) (tactic: trace tactic) (rewrite_database: rewrite_db): unit tactic =
  let rew_rules = find_rewrites rewrite_database in
  (* Rewrites with [c] (no bindings) in direction [dir] at all occurrences;
     [tac] is handed to [general_rewrite] as its [~tac] argument. *)
  let rewrite (dir: bool) (c: constr) (tac: unit tactic): unit tactic =
    let c = (EConstr.of_constr c, Tactypes.NoBindings) in
    general_rewrite ~where ~l2r:dir AllOccurrences ~freeze:true ~dep:false ~with_evars:false ~tac:(tac, AllMatches) c
  in
  (* Instantiates the rule's universe context with fresh universes, merges
     them (as flexible) into the goal's evar map, then rewrites with the
     instantiated lemma. *)
  let try_rewrite (rule: rew_rule) (tac: unit tactic): unit tactic =
    Proofview.Goal.enter begin fun gl ->
      let sigma = Proofview.Goal.sigma gl in
      let subst, ctx' = UnivGen.fresh_universe_context_set_instance rule.rew_ctx in
      let c' = Vars.subst_univs_level_constr subst rule.rew_lemma in
      let sigma = Evd.merge_context_set Evd.univ_flexible sigma ctx' in
      Proofview.tclTHEN (Proofview.Unsafe.tclEVARS sigma) (rewrite rule.rew_l2r c' tac)
    end
  in
  (* Tactic for one rule. The rule's optional tactic is interpreted with an
     empty Ltac environment and [poly = false], with its result discarded.
     Then "rewrite with the rule, then [tactic]" is repeated on the main
     goal. *)
  let eval (rule: rew_rule) =
    let tac = match rule.rew_tac with
      | None -> Proofview.tclUNIT ()
      | Some (Genarg.GenArg (Genarg.Glbwit wit, tac)) ->
        let ist = {
          Geninterp.lfun = Id.Map.empty;
          poly = false;
          extra = Geninterp.TacStore.empty
        } in Ftactic.run (Geninterp.interp wit ist tac) (fun _ -> Proofview.tclUNIT ())
    in Tacticals.tclREPEAT_MAIN (tclTHEN (try_rewrite rule tac) (tclIGNORE tactic))
  in
  (* [tclMAP_rev] comes from outside this file; presumably it chains the
     per-rule tactics in reverse list order - confirm. *)
  let rules = tclMAP_rev eval rew_rules in
  Tacticals.tclREPEAT_MAIN @@ Proofview.tclPROGRESS rules
(** The [autorewrite] tactic: repeatedly rewrites the conclusion with
    [rewrite_database] (via [one_base]) while this makes progress. *)
let autorewrite (tac: trace tactic) (rewrite_database: rewrite_db): unit tactic =
  one_base None tac rewrite_database
  |> Proofview.tclPROGRESS
  |> Tacticals.tclREPEAT_MAIN
(** For every focused goal, rewrites in each hypothesis of [idl], in
    order, repeating [one_base] in that hypothesis while it makes
    progress. *)
let autorewrite_multi_in (idl: variable list) (tac: trace tactic) (rewrite_database: rewrite_db): unit tactic =
  let rewrite_in id =
    Tacticals.tclREPEAT_MAIN (Proofview.tclPROGRESS (one_base (Some id) tac rewrite_database))
  in
  Proofview.Goal.enter (fun _ -> Tacticals.tclMAP rewrite_in idl)
(** Extracts hypothesis names from [l] with [treat_id] and rewrites in
    each of them with [autorewrite_multi_in]. *)
let try_do_hyps (treat_id: 'a -> variable) (l: 'a list): trace tactic -> rewrite_db -> unit tactic =
  let hyp_ids = List.map treat_id l in
  autorewrite_multi_in hyp_ids
(** Rewrites with [rewrite_tab] in the locations selected by [cl].

    The conclusion is rewritten unless [cl.concl_occs] is [NoOccurrences].
    Any other partial occurrence selection (the "at" syntax) is rejected.
    Hypotheses are then processed on the first goal left by the conclusion
    step:
    - none for [Some []];
    - the listed ones for [Some l];
    - every hypothesis of the goal for [None]. *)
let gen_auto_multi_rewrite (tac: trace tactic) (cl: clause) (rewrite_tab: rewrite_db): unit tactic =
  let skip_concl = match cl.concl_occs with NoOccurrences -> true | _ -> false in
  let concl_tac =
    if skip_concl then Proofview.tclUNIT () else autorewrite tac rewrite_tab
  in
  if not (Locusops.is_all_occurrences cl.concl_occs || skip_concl) then
    Tacticals.tclZEROMSG ~info:(Exninfo.reify ())
      (str"The \"at\" syntax isn't available yet for the autorewrite tactic.")
  else
    match cl.onhyps with
    | Some [] -> concl_tac
    | Some hyps ->
      let hyps_tac = try_do_hyps (fun ((_, id), _) -> id) hyps tac rewrite_tab in
      Tacticals.tclTHENFIRST concl_tac hyps_tac
    | None ->
      let all_hyps_tac =
        Proofview.Goal.enter (fun gl ->
          try_do_hyps (fun id -> id) (Tacmach.pf_ids_of_hyps gl) tac rewrite_tab)
      in
      Tacticals.tclTHENFIRST concl_tac all_hyps_tac
(** Computes the type of [c] and decomposes it with
    [decompose_applied_relation]. Fails with a user error located at [loc]
    when the type does not end with an applied relation. *)
let find_applied_relation ?(loc: Loc.t option) (env: Environ.env) sigma c left2right =
  let ctype = Retyping.get_type_of env sigma (EConstr.of_constr c) in
  let not_a_relation () =
    user_err ?loc
      (str "The type " ++ Printer.pr_econstr_env env sigma ctype
       ++ str " of this term does not end with an applied relation.")
  in
  match decompose_applied_relation env sigma c ctype left2right with
  | None -> not_a_relation ()
  | Some info -> info
(** Turns the raw rule [rule] into a [rew_rule] and inserts it into
    [rewrite_database].

    Fails, through [find_applied_relation], when the lemma's type does not
    end with an applied relation. *)
let fill_rewrite_tab (env: Environ.env) (sigma: Evd.evar_map) (rule : raw_rew_rule) (rewrite_database: rewrite_db): rewrite_db =
  let glob_sign = Genintern.empty_glob_sign ~strict:true env in
  let intern (tac: raw_generic_argument): glob_generic_argument =
    let _, glob_tac = Genintern.generic_intern glob_sign tac in
    glob_tac
  in
  let { CAst.loc; v = ((lemma, univs), l2r, raw_tac) } = rule in
  (* The lemma's universes must be known to [sigma] before computing its type. *)
  let sigma = Evd.merge_context_set Evd.univ_rigid sigma univs in
  let info = find_applied_relation ?loc env sigma lemma l2r in
  let pattern = EConstr.Unsafe.to_constr info.hyp_pat in
  let key = fresh_key () in
  let new_rule = {
    rew_id = key;
    rew_lemma = lemma;
    rew_type = EConstr.Unsafe.to_constr info.hyp_ty;
    rew_pat = pattern;
    rew_ctx = univs;
    rew_l2r = l2r;
    rew_tac = Option.map intern raw_tac;
  } in
  add_rew_rules rewrite_database [new_rule]
(** Prints the current rewrite hint database: one line per rule (in
    [find_rewrites] order) giving its direction, lemma, type and optional
    tactic. *)
let print_rewrite_hintdb (env: Environ.env) (sigma: Evd.evar_map) (rewrite_database: rewrite_db) =
  let pr_rule rule =
    let direction = if rule.rew_l2r then "rewrite -> " else "rewrite <- " in
    let pr_tac = match rule.rew_tac with
      | None -> mt ()
      | Some tac -> str " then use tactic " ++ Pputils.pr_glb_generic env sigma tac
    in
    str direction
    ++ Printer.pr_lconstr_env env sigma rule.rew_lemma
    ++ str " of type "
    ++ Printer.pr_lconstr_env env sigma rule.rew_type
    ++ pr_tac
  in
  str "Local rewrite database" ++ fnl ()
  ++ prlist_with_sep fnl pr_rule (find_rewrites rewrite_database)
(**
  Converts a given hypothesis into a raw rule that can be added to the hint rewrite database.

  The hypothesis expression is interpreted as a term. Its universe context
  is declared as monomorphic, which is a global side effect, so the rule
  carries an empty universe context. The rule rewrites left to right and
  has no attached tactic.
*)
let to_raw_rew_rule (env: Environ.env) (sigma: Evd.evar_map) (hyp: Constrexpr.constr_expr): raw_rew_rule =
  let term, uctx = Constrintern.interp_constr env sigma hyp in
  let lemma = EConstr.to_constr sigma term in
  DeclareUctx.declare_universe_context ~poly:false (UState.context_set uctx);
  (* A [None] typed as a raw Ltac argument: no tactic is attached. *)
  let no_tactic = Option.map (in_gen (rawwit wit_ltac)) None in
  CAst.make ?loc:(Constrexpr_ops.constr_loc hyp)
    ((lemma, Univ.ContextSet.empty), true, no_tactic)
(**
  Builds a fresh rewrite database from the hypotheses of the current goal.

  Each hypothesis is turned into a left-to-right rewrite rule (see
  [to_raw_rew_rule]). The database starts from [RewriteDatabase.empty];
  no global database such as "core" is involved. Hypotheses that cannot
  be added, for example because their type does not end with an applied
  relation, are skipped. Critical exceptions (interrupts, timeouts,
  anomalies) still propagate.
*)
let fill_local_rewrite_database (): rewrite_db tactic =
  RewriteDatabaseTactics.typedGoalEnter @@ fun goal ->
    let env = Goal.env goal in
    let sigma = Goal.sigma goal in
    let hyps = List.map (fun decl ->
      Constrexpr_ops.mkIdentC @@ Context.Named.Declaration.get_id decl
    ) (Goal.hyps goal) in
    let new_rules = List.map (to_raw_rew_rule env sigma) hyps in
    let add_rule acc rule =
      (* A catch-all here would also swallow Ctrl-C, timeouts and anomalies. *)
      try fill_rewrite_tab env sigma rule acc
      with e when CErrors.noncritical e -> acc
    in
    tclUNIT @@ List.fold_left add_rule RewriteDatabase.empty new_rules
(**
  Waterproof autorewrite.

  A rewrite of coq-core's [autorewrite] tactic whose only rewrite hints are
  the hypotheses of the current goal. Rewriting happens in the goal's
  conclusion only, repeated while it makes progress.

  @param print_hints when [true], prints the local hint database (per goal).
  @param log when [true], emits a notice that the tactic was applied.
  @param tac tactic passed down to be run after each successful rewrite.
*)
let wp_autorewrite ?(print_hints: bool = false) (log: bool) (tac: trace tactic): unit tactic =
  let conclusion_only = { onhyps = Some []; concl_occs = Locus.AllOccurrences } in
  let rewrite_goal rewrite_db goal =
    let env = Goal.env goal in
    let sigma = Goal.sigma goal in
    if print_hints then Feedback.msg_notice (print_rewrite_hintdb env sigma rewrite_db);
    if log then Feedback.msg_notice (str "(* application of wp_autorewrite *)");
    Tacticals.tclREPEAT (tclPROGRESS (gen_auto_multi_rewrite tac conclusion_only rewrite_db))
  in
  fill_local_rewrite_database () >>= fun rewrite_db ->
  Goal.enter (rewrite_goal rewrite_db) >>= fun _ -> tclUNIT ()