Source file specialize.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
open Ast
open Ast_defs
open Ast_util
open Rewriter
(** When [Some (prefix, n)], every run of [specialize_ids] writes its AST
    to the file [prefix ^ "_spec_" ^ string_of_int n ^ ".sail"] and then
    stores [Some (prefix, n + 1)]. [None] disables the dump. *)
let opt_ddump_spec_ast : (string * int) option ref = ref None
(** [true] exactly when the type argument is a type ([A_typ]). *)
let is_typ_arg (A_aux (arg, _)) =
  match arg with
  | A_typ _ -> true
  | _ -> false
(** Parameters that decide what a specialization pass acts on. *)
type specialization = {
is_polymorphic : kinded_id -> bool; (** Quantifier variables for which a val-spec counts as polymorphic; used both to pick candidate functions and to drop instantiated variables from the new quantifier. *)
instantiation_filter : kid -> typ_arg -> bool; (** Keeps only those bindings [kid => arg] of a call-site instantiation that the pass specializes on. *)
extern_filter : extern option -> bool; (** Applied to a val-spec's externs; when it returns [true] the function is not specialized. *)
}
(** Specialize over type-kinded quantifiers, keeping only the bindings
    whose argument is a type. No function is excluded because of its
    externs. *)
let typ_specialization =
  let never_excluded _externs = false in
  {
    extern_filter = never_excluded;
    instantiation_filter = (fun _kid arg -> is_typ_arg arg);
    is_polymorphic = is_typ_kopt;
  }
(** Instantiation filter shared by the integer specializations: keep a
    binding only when its argument is an integer literal. *)
let int_constant_instantiation _kid = function
  | A_aux (A_nexp (Nexp_aux (Nexp_constant _, _)), _) -> true
  | _ -> false

(** Specialize integer-kinded quantifiers to constant values. Functions
    that have a ["c"] extern binding are not specialized. *)
let int_specialization =
  {
    is_polymorphic = is_int_kopt;
    instantiation_filter = int_constant_instantiation;
    extern_filter = (fun externs -> match Ast_util.extern_assoc "c" externs with Some _ -> true | None -> false);
  }

(** Like [int_specialization], but functions are specialized whatever
    their extern bindings. *)
let int_specialization_with_externs = { int_specialization with extern_filter = (fun _ -> false) }
(** Rebuild a type with [nexp_simp] applied to every numeric argument and
    [constraint_simp] to every boolean argument found in type
    applications. Existential constraints are left as they are.
    Locations are preserved. *)
let rec nexp_simp_typ (Typ_aux (aux, l)) =
  let simp_aux = function
    | (Typ_id _ | Typ_var _) as leaf -> leaf
    | Typ_tuple elems -> Typ_tuple (List.map nexp_simp_typ elems)
    | Typ_app (ctor, args) -> Typ_app (ctor, List.map nexp_simp_typ_arg args)
    | Typ_exist (kopts, nc, body) -> Typ_exist (kopts, nc, nexp_simp_typ body)
    | Typ_fn (params, ret) -> Typ_fn (List.map nexp_simp_typ params, nexp_simp_typ ret)
    | Typ_bidir (lhs, rhs) -> Typ_bidir (nexp_simp_typ lhs, nexp_simp_typ rhs)
    | Typ_internal_unknown -> Reporting.unreachable l __POS__ "escaped Typ_internal_unknown"
  in
  Typ_aux (simp_aux aux, l)

(** Type-argument counterpart of [nexp_simp_typ]. *)
and nexp_simp_typ_arg (A_aux (aux, l)) =
  let simplified =
    match aux with
    | A_nexp nexp -> A_nexp (nexp_simp nexp)
    | A_typ typ -> A_typ (nexp_simp_typ typ)
    | A_bool nc -> A_bool (constraint_simp nc)
  in
  A_aux (simplified, l)
(** Normalise a call-site instantiation for [spec]:
    - drop the bindings rejected by [spec.instantiation_filter];
    - rename each remaining kid to [Type_check.orig_kid kid];
    - simplify its argument with [nexp_simp_typ_arg].

    Bindings are visited in increasing key order, so when two kids share
    an original name the later one wins. *)
let fix_instantiation spec instantiation =
  KBindings.fold
    (fun kid arg acc ->
      if spec.instantiation_filter kid arg then KBindings.add (Type_check.orig_kid kid) (nexp_simp_typ_arg arg) acc
      else acc
    )
    instantiation KBindings.empty
(** Ids of all val-specs in [defs] that quantify over at least one
    variable selected by [ctx.is_polymorphic] and are not excluded by
    [ctx.extern_filter]. *)
let polymorphic_functions ctx defs =
  List.fold_left
    (fun found def ->
      match def with
      | DEF_aux (DEF_val (VS_aux (VS_val_spec (TypSchm_aux (TypSchm_ts (typq, _), _), id, externs), _)), _)
        when List.exists ctx.is_polymorphic (quant_kopts typq) && not (ctx.extern_filter externs) ->
          IdSet.add id found
      | _ -> found
    )
    IdSet.empty defs
(** Render an instantiation as the string used to name a specialized
    function (see [id_of_instantiation]). Each binding is printed as
    [kid=>arg], with bindings separated by [", "], and the whole string is
    passed through [Util.zencode_string]. Inside the arguments, variables
    are not printed by name. Each distinct (kind, variable) pair is given
    a number in order of first encounter, so instantiations that differ
    only in those variable names yield the same string. The binding keys
    themselves are printed with [string_of_kid]. *)
let string_of_instantiation instantiation =
let open Type_check in
(* Numbers already assigned to kinded variables, and the next free number. *)
let kid_names = ref KOptMap.empty in
let kid_counter = ref 0 in
(* The number assigned to [kid], allocating the next one on first use. *)
let kid_name kid =
try KOptMap.find kid !kid_names
with Not_found ->
let n = string_of_int !kid_counter in
kid_names := KOptMap.add kid n !kid_names;
incr kid_counter;
n
in
(* Custom printers (not the generic string_of_* ones) so that every
   variable goes through [kid_name].
   NOTE(review): the numbering follows the order in which OCaml evaluates
   the operands of [^] and the other calls below. That order is
   unspecified by the language, so restructuring these expressions may
   change generated names. *)
let rec string_of_nexp = function Nexp_aux (nexp, _) -> string_of_nexp_aux nexp
and string_of_nexp_aux = function
| Nexp_id id -> string_of_id id
| Nexp_var kid -> kid_name (mk_kopt K_int kid)
| Nexp_constant c -> Big_int.to_string c
| Nexp_times (n1, n2) -> "(" ^ string_of_nexp n1 ^ " * " ^ string_of_nexp n2 ^ ")"
| Nexp_sum (n1, n2) -> "(" ^ string_of_nexp n1 ^ " + " ^ string_of_nexp n2 ^ ")"
| Nexp_minus (n1, n2) -> "(" ^ string_of_nexp n1 ^ " - " ^ string_of_nexp n2 ^ ")"
| Nexp_app (id, nexps) -> string_of_id id ^ "(" ^ Util.string_of_list "," string_of_nexp nexps ^ ")"
| Nexp_exp n -> "2 ^ " ^ string_of_nexp n
| Nexp_neg n -> "- " ^ string_of_nexp n
| Nexp_if (i, t, e) ->
"(if " ^ string_of_n_constraint i ^ " then " ^ string_of_nexp t ^ " else " ^ string_of_nexp e ^ ")"
and string_of_typ = function Typ_aux (typ, l) -> string_of_typ_aux typ
and string_of_typ_aux = function
| Typ_id id -> string_of_id id
| Typ_var kid -> kid_name (mk_kopt K_type kid)
| Typ_tuple typs -> "(" ^ Util.string_of_list ", " string_of_typ typs ^ ")"
| Typ_app (id, args) -> string_of_id id ^ "(" ^ Util.string_of_list "," string_of_typ_arg args ^ ")"
| Typ_fn (arg_typs, ret_typ) ->
"(" ^ Util.string_of_list ", " string_of_typ arg_typs ^ ") -> " ^ string_of_typ ret_typ
| Typ_bidir (t1, t2) -> string_of_typ t1 ^ " <-> " ^ string_of_typ t2
(* Existential binders are kinded ids, so they are numbered through [kid_name] directly. *)
| Typ_exist (kids, nc, typ) ->
"exist " ^ Util.string_of_list " " kid_name kids ^ ", " ^ string_of_n_constraint nc ^ ". " ^ string_of_typ typ
| Typ_internal_unknown -> "UNKNOWN"
and string_of_typ_arg = function A_aux (typ_arg, l) -> string_of_typ_arg_aux typ_arg
and string_of_typ_arg_aux = function
| A_nexp n -> string_of_nexp n
| A_typ typ -> string_of_typ typ
| A_bool nc -> string_of_n_constraint nc
and string_of_n_constraint = function
| NC_aux (NC_id id, _) -> string_of_id id
| NC_aux (NC_equal (t1, t2), _) -> string_of_typ_arg t1 ^ " == " ^ string_of_typ_arg t2
| NC_aux (NC_not_equal (t1, t2), _) -> string_of_typ_arg t1 ^ " != " ^ string_of_typ_arg t2
| NC_aux (NC_ge (n1, n2), _) -> string_of_nexp n1 ^ " >= " ^ string_of_nexp n2
| NC_aux (NC_gt (n1, n2), _) -> string_of_nexp n1 ^ " > " ^ string_of_nexp n2
| NC_aux (NC_le (n1, n2), _) -> string_of_nexp n1 ^ " <= " ^ string_of_nexp n2
| NC_aux (NC_lt (n1, n2), _) -> string_of_nexp n1 ^ " < " ^ string_of_nexp n2
| NC_aux (NC_or (nc1, nc2), _) -> "(" ^ string_of_n_constraint nc1 ^ " | " ^ string_of_n_constraint nc2 ^ ")"
| NC_aux (NC_and (nc1, nc2), _) -> "(" ^ string_of_n_constraint nc1 ^ " & " ^ string_of_n_constraint nc2 ^ ")"
| NC_aux (NC_set (n, ns), _) -> string_of_nexp n ^ " in {" ^ Util.string_of_list ", " Big_int.to_string ns ^ "}"
| NC_aux (NC_true, _) -> "true"
| NC_aux (NC_false, _) -> "false"
| NC_aux (NC_var kid, _) -> kid_name (mk_kopt K_bool kid)
| NC_aux (NC_app (id, args), _) -> string_of_id id ^ "(" ^ Util.string_of_list "," string_of_typ_arg args ^ ")"
in
(* Binding keys keep their real names; only variables inside the arguments are numbered. *)
let string_of_binding (kid, arg) = string_of_kid kid ^ "=>" ^ string_of_typ_arg arg in
Util.zencode_string (Util.string_of_list ", " string_of_binding (KBindings.bindings instantiation))
(** Name of the specialization of [id] for [instantiation]: the
    instantiation string followed by ["#"], prepended to [id]. *)
let id_of_instantiation id instantiation = prepend_id (string_of_instantiation instantiation ^ "#") id
(** The fully generic type of the variant [id] declared in [defs]: its
    name applied to one type variable per quantified variable.
    @raise Failure if [defs] declares no variant named [id]. *)
let rec variant_generic_typ id = function
  | [] -> failwith ("No variant with id " ^ string_of_id id)
  | DEF_aux (DEF_type (TD_aux (TD_variant (id', typq, _, _), _)), _) :: _ when Id.compare id id' = 0 ->
      let as_type_var kopt = mk_typ_arg (A_typ (mk_typ (Typ_var (kopt_kid kopt)))) in
      mk_typ (Typ_app (id', List.map as_type_var (quant_kopts typq)))
  | _ :: rest -> variant_generic_typ id rest
(** [instantiations_of spec id ast] collects the instantiation of [id]'s
    quantifiers at every occurrence of [id] in [ast], each normalised by
    [fix_instantiation spec]. Occurrences are applications [id(...)] and,
    because [id] may be a union constructor that only ever appears in
    patterns, constructor patterns [id(...)]. Duplicates are kept, and
    the list is in reverse order of discovery.
    @raise Failure if a matching constructor pattern has a type that is
    neither a type application nor a plain type id. *)
let instantiations_of spec id ast =
  let instantiations = ref [] in
  let push_instantiation instantiation = instantiations := fix_instantiation spec instantiation :: !instantiations in
  let inspect_exp = function
    | E_aux (E_app (id', _), _) as exp when Id.compare id id' = 0 ->
        push_instantiation (Type_check.instantiation_of exp);
        exp
    | exp -> exp
  in
  let inspect_pat = function
    | P_aux (P_app (id', _), annot) as pat when Id.compare id id' = 0 -> begin
        match Type_check.typ_of_annot annot with
        | Typ_aux (Typ_app (variant_id, _), _) as typ ->
            let instantiation =
              let open Type_check in
              (* Look the generic variant type up once. It supplies both the
                 variables to solve for and the type to unify against. *)
              let generic_typ = variant_generic_typ variant_id ast.defs in
              unify (fst annot) (env_of_annot annot) (tyvars_of_typ generic_typ) generic_typ typ
            in
            push_instantiation instantiation;
            pat
        | Typ_aux (Typ_id _, _) -> pat
        | _ -> failwith ("Union constructor " ^ string_of_pat pat ^ " has non-union type")
      end
    | pat -> pat
  in
  let rewrite_pat = { id_pat_alg with p_aux = (fun (pat, annot) -> inspect_pat (P_aux (pat, annot))) } in
  let rewrite_exp =
    { id_exp_alg with pat_alg = rewrite_pat; e_aux = (fun (exp, annot) -> inspect_exp (E_aux (exp, annot))) }
  in
  (* The traversal is run only for its side effect on [instantiations]. *)
  let _ =
    rewrite_ast_base
      {
        rewriters_base with
        rewrite_exp = (fun _ -> fold_exp rewrite_exp);
        rewrite_pat = (fun _ -> fold_pat rewrite_pat);
      }
      ast
  in
  !instantiations
(** Redirect each call to [id] in [ast] to the specialized function for
    that call's instantiation. A call is redirected only when [ast]
    already contains a val-spec with the specialized name; otherwise the
    generic call is kept. *)
let rewrite_polymorphic_calls spec id ast =
  let specialized_ids = val_spec_ids ast.defs in
  let redirect exp =
    match exp with
    | E_aux (E_app (callee, args), annot) when Id.compare id callee = 0 ->
        let target = id_of_instantiation id (fix_instantiation spec (Type_check.instantiation_of exp)) in
        if IdSet.mem target specialized_ids then E_aux (E_app (target, args), annot) else exp
    | _ -> exp
  in
  let alg = { id_exp_alg with e_aux = (fun (aux, annot) -> redirect (E_aux (aux, annot))) } in
  rewrite_ast_base { rewriters_base with rewrite_exp = (fun _ -> fold_exp alg) } ast
(** [typ_frees ?exs typ] is the set of type variables ([Typ_var]) that
    occur free in [typ] and are not in [exs]. The binders of an
    existential are added to [exs] for its body. Nested existentials
    accumulate binders; they do not replace the outer ones. Numeric and
    boolean type arguments contribute nothing. *)
let rec typ_frees ?(exs = KidSet.empty) (Typ_aux (typ_aux, l)) =
  match typ_aux with
  | Typ_id _ -> KidSet.empty
  | Typ_var kid when KidSet.mem kid exs -> KidSet.empty
  | Typ_var kid -> KidSet.singleton kid
  | Typ_tuple typs -> List.fold_left KidSet.union KidSet.empty (List.map (typ_frees ~exs) typs)
  | Typ_app (_, args) -> List.fold_left KidSet.union KidSet.empty (List.map (typ_arg_frees ~exs) args)
  | Typ_exist (kopts, _, typ) ->
      (* Extend, rather than replace, the variables bound by enclosing existentials. *)
      typ_frees ~exs:(KidSet.union exs (KidSet.of_list (List.map kopt_kid kopts))) typ
  | Typ_fn (arg_typs, ret_typ) ->
      List.fold_left KidSet.union (typ_frees ~exs ret_typ) (List.map (typ_frees ~exs) arg_typs)
  | Typ_bidir (t1, t2) -> KidSet.union (typ_frees ~exs t1) (typ_frees ~exs t2)
  | Typ_internal_unknown -> Reporting.unreachable l __POS__ "escaped Typ_internal_unknown"

(** Type-argument counterpart of [typ_frees]. Only [A_typ] arguments can
    contain free type variables. *)
and typ_arg_frees ?(exs = KidSet.empty) (A_aux (typ_arg_aux, _)) =
  match typ_arg_aux with A_nexp _ -> KidSet.empty | A_typ typ -> typ_frees ~exs typ | A_bool _ -> KidSet.empty
(** [typ_int_frees ?exs typ] is the set of variables that occur free in
    the numeric arguments ([A_nexp]) found anywhere in [typ], minus
    [exs]. The binders of an existential are added to [exs] for its body.
    Nested existentials accumulate binders, so a variable bound by an
    outer existential is not reported as free inside an inner one.
    Boolean arguments contribute nothing. *)
let rec typ_int_frees ?(exs = KidSet.empty) (Typ_aux (typ_aux, l)) =
  match typ_aux with
  | Typ_id _ -> KidSet.empty
  | Typ_var _ -> KidSet.empty
  | Typ_tuple typs -> List.fold_left KidSet.union KidSet.empty (List.map (typ_int_frees ~exs) typs)
  | Typ_app (_, args) -> List.fold_left KidSet.union KidSet.empty (List.map (typ_arg_int_frees ~exs) args)
  | Typ_exist (kopts, _, typ) ->
      (* Extend, rather than replace, the variables bound by enclosing existentials. *)
      typ_int_frees ~exs:(KidSet.union exs (KidSet.of_list (List.map kopt_kid kopts))) typ
  | Typ_fn (arg_typs, ret_typ) ->
      List.fold_left KidSet.union (typ_int_frees ~exs ret_typ) (List.map (typ_int_frees ~exs) arg_typs)
  | Typ_bidir (t1, t2) -> KidSet.union (typ_int_frees ~exs t1) (typ_int_frees ~exs t2)
  | Typ_internal_unknown -> Reporting.unreachable l __POS__ "escaped Typ_internal_unknown"

(** Type-argument counterpart of [typ_int_frees]. *)
and typ_arg_int_frees ?(exs = KidSet.empty) (A_aux (typ_arg_aux, _)) =
  match typ_arg_aux with
  | A_nexp n -> KidSet.diff (tyvars_of_nexp n) exs
  | A_typ typ -> typ_int_frees ~exs typ
  | A_bool _ -> KidSet.empty
(** Replace every [implicit(...)] type application with [atom(...)].
    The rewrite recurses through tuples, function and bidirectional
    types, existential bodies and type-kinded arguments. Numeric and
    boolean arguments are left untouched. *)
let rec remove_implicit (Typ_aux (aux, l)) =
  let aux' =
    match aux with
    | Typ_internal_unknown | Typ_id _ | Typ_var _ -> aux
    | Typ_tuple elems -> Typ_tuple (List.map remove_implicit elems)
    | Typ_fn (params, ret) -> Typ_fn (List.map remove_implicit params, remove_implicit ret)
    | Typ_bidir (lhs, rhs) -> Typ_bidir (remove_implicit lhs, remove_implicit rhs)
    | Typ_app (Id_aux (Id "implicit", _), args) -> Typ_app (mk_id "atom", List.map remove_implicit_arg args)
    | Typ_app (ctor, args) -> Typ_app (ctor, List.map remove_implicit_arg args)
    | Typ_exist (kopts, nc, body) -> Typ_exist (kopts, nc, remove_implicit body)
  in
  Typ_aux (aux', l)

(** Type-argument counterpart of [remove_implicit]. *)
and remove_implicit_arg (A_aux (aux, l)) =
  match aux with
  | A_typ typ -> A_aux (A_typ (remove_implicit typ), l)
  | _ -> A_aux (aux, l)
(** The type argument that refers to the variable of [kopt], built
    according to its kind (numeric, type or boolean). *)
let kopt_arg (KOpt_aux (KOpt_kind (K_aux (kind, _), kid), _)) =
  match kind with
  | K_int -> arg_nexp (nvar kid)
  | K_type -> arg_typ (mk_typ (Typ_var kid))
  | K_bool -> arg_bool (nc_var kid)
(** Split [instantiation] into a pair [(forward, reverse)].

    Every variable occurring in one of the instantiated arguments is
    renamed in [forward] to an ["i#"]-prefixed temporary. [reverse] maps
    each temporary back to the original variable.

    Substituting with [forward] and then [reverse] therefore never feeds
    the result of one binding into another. For example, with
    ['n => 'm, 'm => 8], ['n] becomes ['m], not [8]. *)
let safe_instantiation instantiation =
  let free_kopts =
    KBindings.fold (fun _ arg acc -> KOptSet.union (kopts_of_typ_arg arg) acc) instantiation KOptSet.empty
    |> KOptSet.elements
  in
  let step (forward, reverse) kopt =
    let kid = kopt_kid kopt in
    let tmp = prepend_kid "i#" kid in
    (KBindings.map (fun arg -> subst_kid typ_arg_subst kid tmp arg) forward, KBindings.add tmp (kopt_arg kopt) reverse)
  in
  List.fold_left step (instantiation, KBindings.empty) free_kopts
(** Apply every binding of [instantiation], in key order, to each
    constraint in [ncs]. *)
let instantiate_constraints instantiation ncs =
  let bindings = KBindings.bindings instantiation in
  let instantiate nc = List.fold_left (fun acc (kid, arg) -> constraint_subst kid arg acc) nc bindings in
  List.map instantiate ncs
(** [specialize_id_valspec spec instantiations id ast effect_info] inserts
    new val-specs directly after the val-spec of [id]: one for each
    distinct name [id_of_instantiation id inst], for [inst] in
    [instantiations].

    Each new val-spec keeps [id]'s externs and annotations, but has the
    instantiated type and a correspondingly reduced quantifier. The effect
    information of [id] is copied to every new name.

    Reports an internal error if [id] has no val-spec. Returns the updated
    AST and effect information. *)
let specialize_id_valspec spec instantiations id ast effect_info =
  match split_defs (is_valspec id) ast.defs with
  | None -> Reporting.unreachable (id_loc id) __POS__ ("Valspec " ^ string_of_id id ^ " does not exist!")
  | Some (pre_defs, vs, post_defs) ->
      let typschm, externs, annot, def_annot =
        match vs with
        | DEF_aux (DEF_val (VS_aux (VS_val_spec (typschm, _, externs), annot)), def_annot) ->
            (typschm, externs, annot, def_annot)
        | _ -> Reporting.unreachable (id_loc id) __POS__ "val-spec is not actually a val-spec"
      in
      let (TypSchm_aux (TypSchm_ts (typq, typ), _)) = typschm in
      (* Specialized ids generated so far: repeated instantiations produce a single val-spec. *)
      let spec_ids = ref IdSet.empty in
      let specialize_instance instantiation =
        (* Quantified variables of [id] that this instantiation leaves unbound. *)
        let uninstantiated =
          quant_kopts typq |> List.map kopt_kid
          |> List.filter (fun v -> not (KBindings.mem v instantiation))
          |> KidSet.of_list
        in
        (* Type- and integer-kinded variables occurring free in the instantiated arguments. *)
        let collect_kids kidsets = KidSet.elements (List.fold_left KidSet.union KidSet.empty kidsets) in
        let typ_frees = KBindings.bindings instantiation |> List.map snd |> List.map typ_arg_frees |> collect_kids in
        let int_frees =
          KBindings.bindings instantiation |> List.map snd |> List.map typ_arg_int_frees |> collect_kids
        in
        (* A variable introduced by the instantiation may share its name with a
           quantified variable that stays unbound. Rename the latter with an
           "o#" prefix so the two stay distinct. *)
        let typq, typ =
          List.fold_left
            (fun (typq, typ) free ->
              if KidSet.mem free uninstantiated then (
                let fresh_v = prepend_kid "o#" free in
                (typquant_subst_kid free fresh_v typq, subst_kid typ_subst free fresh_v typ)
              )
              else (typq, typ)
            )
            (typq, typ) (typ_frees @ int_frees)
        in
        (* Substitute through temporary "i#" variables and then map them back,
           so that one binding's result is never rewritten by another binding. *)
        let safe_instantiation, reverse = safe_instantiation instantiation in
        let typ =
          remove_implicit (Type_check.subst_unifiers reverse (Type_check.subst_unifiers safe_instantiation typ))
        in
        let kopts, constraints = quant_split typq in
        let constraints = instantiate_constraints safe_instantiation constraints in
        let constraints = instantiate_constraints reverse constraints in
        (* Instantiated polymorphic variables are no longer quantified. *)
        let kopts =
          List.filter
            (fun kopt -> not (spec.is_polymorphic kopt && KBindings.mem (kopt_kid kopt) safe_instantiation))
            kopts
        in
        let typq =
          match (typ_frees, int_frees, kopts) with
          | [], [], [] -> mk_typquant []
          | _ ->
              mk_typquant
                (List.map (mk_qi_id K_type) typ_frees
                @ List.map (mk_qi_id K_int) int_frees
                @ List.map mk_qi_kopt kopts @ List.map mk_qi_nc constraints
                )
        in
        let typschm = mk_typschm typq typ in
        let spec_id = id_of_instantiation id instantiation in
        if IdSet.mem spec_id !spec_ids then []
        else begin
          spec_ids := IdSet.add spec_id !spec_ids;
          [DEF_aux (DEF_val (VS_aux (VS_val_spec (typschm, spec_id, externs), annot)), def_annot)]
        end
      in
      let specializations = List.map specialize_instance instantiations |> List.concat in
      let effect_info =
        IdSet.fold (fun id' effect_info -> Effects.copy_function_effect id effect_info id') !spec_ids effect_info
      in
      ({ ast with defs = pre_defs @ (vs :: specializations) @ post_defs }, effect_info)
(** Apply [instantiation] (via [Type_check.subst_unifiers]) to every type
    annotation in the patterns, expressions and l-expressions of [fdef].
    Also replace the function's return-type annotation with
    [Typ_annot_opt_none]. *)
let specialize_annotations instantiation fdef =
  let open Type_check in
  let subst typ = subst_unifiers instantiation typ in
  let pat_alg = { id_pat_alg with p_typ = (fun (typ, pat) -> P_typ (subst typ, pat)) } in
  let exp_alg =
    {
      id_exp_alg with
      e_typ = (fun (typ, exp) -> E_typ (subst typ, exp));
      le_typ = (fun (typ, lexp) -> LE_typ (subst typ, lexp));
      pat_alg;
    }
  in
  let (FD_aux (FD_function (rec_opt, _, funcls), annot)) =
    rewrite_fun
      { rewriters_base with rewrite_exp = (fun _ -> fold_exp exp_alg); rewrite_pat = (fun _ -> fold_pat pat_alg) }
      fdef
  in
  FD_aux (FD_function (rec_opt, Typ_annot_opt_aux (Typ_annot_opt_none, Parse_ast.Unknown), funcls), annot)
(** Insert, right after the definition of [id], one renamed copy of it for
    each distinct specialized name [id_of_instantiation id inst]. Each
    copy's annotations are specialized with [specialize_annotations].
    If [id] has no function definition, [ast] is returned unchanged. *)
let specialize_id_fundef instantiations id ast =
  match split_defs (is_fundef id) ast.defs with
  | None -> ast
  | Some (pre_defs, (DEF_aux (DEF_fundef fundef, def_annot) as generic_def), post_defs) ->
      let seen = ref IdSet.empty in
      let specialized =
        List.filter_map
          (fun instantiation ->
            let spec_id = id_of_instantiation id instantiation in
            if IdSet.mem spec_id !seen then None
            else (
              seen := IdSet.add spec_id !seen;
              Some
                (DEF_aux (DEF_fundef (specialize_annotations instantiation (rename_fundef spec_id fundef)), def_annot))
            )
          )
          instantiations
      in
      { ast with defs = pre_defs @ (generic_def :: specialized) @ post_defs }
  | Some _ -> assert false
(** In every overload declaration, replace [id] with the set of its
    specialized names, listed in [IdSet] order. *)
let specialize_id_overloads instantiations id ast =
  let spec_ids = IdSet.elements (IdSet.of_list (List.map (id_of_instantiation id) instantiations)) in
  let expand id' = if Id.compare id' id = 0 then spec_ids else [id'] in
  let rewrite_def = function
    | DEF_aux (DEF_overload (overload_id, overloads), def_annot) ->
        DEF_aux (DEF_overload (overload_id, List.concat (List.map expand overloads)), def_annot)
    | def -> def
  in
  { ast with defs = List.map rewrite_def ast.defs }
(** Functions that [remove_unused_valspecs] always treats as called, in
    addition to those called somewhere in the AST. Extended by
    [add_initial_calls].
    NOTE(review): "append_64" is presumably needed by a backend even
    when no source call exists — confirm. *)
let initial_calls =
  [
    "main";
    "__InitConfig";
    "__SetConfig";
    "__ListConfig";
    "execute";
    "decode";
    "initialize_registers";
    "prop";
    "append_64";
  ]
  |> List.map mk_id |> IdSet.of_list |> ref
(** Add [ids] to the functions that are never pruned as unused. *)
let add_initial_calls ids =
  let extended = IdSet.union ids !initial_calls in
  initial_calls := extended

(** The functions that are never pruned as unused, in [IdSet] order. *)
let get_initial_calls () =
  let current = !initial_calls in
  IdSet.elements current
(** Remove the val-spec and function definition of every function that has
    a val-spec but is neither in [initial_calls] nor applied ([E_app])
    anywhere in [ast]. Such ids are also removed from overloads, and an
    overload left empty is dropped. [env] is not used. *)
let remove_unused_valspecs env ast =
  let declared = val_spec_ids ast.defs in
  let called = ref !initial_calls in
  let note_call exp =
    (match exp with E_aux (E_app (callee, _), _) -> called := IdSet.add callee !called | _ -> ());
    exp
  in
  let collector = { id_exp_alg with e_aux = (fun (aux, annot) -> note_call (E_aux (aux, annot))) } in
  (* Traversed only to fill [called]. *)
  let _ = rewrite_ast_base { rewriters_base with rewrite_exp = (fun _ -> fold_exp collector) } ast in
  let unused = IdSet.filter (fun vs_id -> not (IdSet.mem vs_id !called)) declared in
  let prune id def =
    if is_fundef id def || is_valspec id def then None
    else (
      match def with
      | DEF_aux (DEF_overload (overload_id, overloads), def_annot) -> begin
          match List.filter (fun id' -> Id.compare id id' <> 0) overloads with
          | [] -> None
          | remaining -> Some (DEF_aux (DEF_overload (overload_id, remaining), def_annot))
        end
      | _ -> Some def
    )
  in
  IdSet.fold (fun id ast -> { ast with defs = List.filter_map (prune id) ast.defs }) unused ast
(** Specialize one polymorphic function [id]. For every instantiation
    found in [ast], add specialized val-specs, function definitions and
    overload entries. *)
let specialize_id spec id ast effect_info =
  let instantiations = instantiations_of spec id ast in
  let ast, effect_info = specialize_id_valspec spec instantiations id ast effect_info in
  let ast = ast |> specialize_id_fundef instantiations id |> specialize_id_overloads instantiations id in
  (ast, effect_info)
(** Move all type and default definitions to the front of the AST.
    Relative order is preserved within both groups. *)
let reorder_typedefs ast =
  let is_typedef = function DEF_aux ((DEF_default _ | DEF_type _), _) -> true | _ -> false in
  let typedefs, others = List.partition is_typedef ast.defs in
  { ast with defs = typedefs @ others }
(** Run one specialization pass over the functions in [ids]:
    1. specialize each function;
    2. move type definitions first;
    3. optionally dump the AST (see [opt_ddump_spec_ast]);
    4. re-typecheck;
    5. redirect calls to the specializations;
    6. re-typecheck again and prune unused val-specs.

    Returns the new AST, the environment from the final typecheck, and the
    updated effect information. *)
let specialize_ids spec ids ast effect_info =
  let t = Profile.start () in
  let total = IdSet.cardinal ids in
  let _, (ast, effect_info) =
    List.fold_left
      (fun (n, (ast, effect_info)) id ->
        Util.progress "Specializing " (string_of_id id) n total;
        (n + 1, specialize_id spec id ast effect_info)
      )
      (1, (ast, effect_info))
      (IdSet.elements ids)
  in
  let ast = reorder_typedefs ast in
  begin
    match !opt_ddump_spec_ast with
    | Some (f, i) ->
        let filename = f ^ "_spec_" ^ string_of_int i ^ ".sail" in
        let out_chan = open_out filename in
        (* Do not leak the channel if pretty-printing fails. Close it
           quietly and re-raise with the original backtrace. *)
        (try Pretty_print_sail.output_ast out_chan (Type_check.strip_ast ast)
         with exn ->
           let bt = Printexc.get_raw_backtrace () in
           close_out_noerr out_chan;
           Printexc.raise_with_backtrace exn bt
        );
        close_out out_chan;
        opt_ddump_spec_ast := Some (f, i + 1)
    | None -> ()
  end;
  let ast, _ = Type_error.check Type_check.initial_env (Type_check.strip_ast ast) in
  let _, ast =
    List.fold_left
      (fun (n, ast) id ->
        Util.progress "Rewriting " (string_of_id id) n total;
        (n + 1, rewrite_polymorphic_calls spec id ast)
      )
      (1, ast) (IdSet.elements ids)
  in
  let ast, env = Type_error.check Type_check.initial_env (Type_check.strip_ast ast) in
  let ast = remove_unused_valspecs env ast in
  Profile.finish "specialization pass" t;
  (ast, env, effect_info)
(** Run at most [n] specialization passes, stopping early once no
    polymorphic functions remain. A negative [n] only stops when none
    remain. *)
let rec specialize_passes n spec env ast effect_info =
  if n = 0 then (ast, env, effect_info)
  else (
    match polymorphic_functions spec ast.defs with
    | ids when IdSet.is_empty ids -> (ast, env, effect_info)
    | ids ->
        let ast, env, effect_info = specialize_ids spec ids ast effect_info in
        specialize_passes (n - 1) spec env ast effect_info
  )
(** Specialize repeatedly until no polymorphic functions remain. The
    pass count starts at -1 and only decreases, so there is no pass limit. *)
let specialize spec env ast effect_info = specialize_passes (-1) spec env ast effect_info
(* Register the interactive "specialize" command. It runs [specialize]
   with [typ_specialization] and stores the resulting AST, environment
   and effect information back into the interactive state. *)
let () =
  Interactive.(
    register_command ~name:"specialize" ~help:"Specialize type variables in the AST"
      (Action
         (fun istate ->
           let ast, env, effect_info = specialize typ_specialization istate.env istate.ast istate.effect_info in
           { istate with ast; env; effect_info }
         )
      )
  )