Source file infer.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
(** Type inference and checking *)
open Common
open Lplib
open Term
open Timed
open Print
(** Logging function for typing. The logger is first bound to [log] and
    then shadowed by its [pp] field. Keeping the second binding a plain
    field access on a variable keeps [log] a syntactic value, so it is
    generalised; this matters because [log] is used below with formats of
    different types. *)
let log = Logger.make 'i' "infr" "type inference/checking"
let log = log.pp
(** A typing context paired with its boxed counterpart: [(c, bc)]. *)
type octxt = ctxt * bctxt
(** [boxed c] is the boxed component of the pair [c]. *)
let boxed (_, bctx) = bctx
(** [classic c] is the classic (unboxed) component of the pair [c]. *)
let classic (cctx, _) = cctx
(** [extend c v ?def ty] extends the context pair [c] with variable [v] of
    type [ty], optionally defined as [def]. The classic context always
    receives [(v, ty, def)]; the boxed context receives [(v, lift ty)] only
    when no definition is given. *)
let extend (cctx, bctx) v ?def ty =
  (* Pattern-match on the option rather than using polymorphic [<>]: terms
     may contain closures (Bindlib binders), on which structural comparison
     is fragile. *)
  let bctx =
    match def with
    | None -> (v, lift ty) :: bctx
    | Some _ -> bctx
  in
  ((v, ty, def) :: cctx, bctx)
(** [unbox b] is a shorthand for [Bindlib.unbox b]. *)
let unbox b = Bindlib.unbox b
(** Exception raised by type inference when a term cannot be typed, e.g. when
    the body of a let binding has type [Kind]. {!val:noexn} catches it and
    returns [None]. *)
exception NotTypable
(** [unif pb c a b] registers the unification problem [c ⊢ a ≡ b] in [pb].
    Nothing happens if [a] and [b] are already equal according to
    {!val:Eval.pure_eq_modulo} in the classic part of [c]. Otherwise the
    constraint is prepended to the [to_solve] field of [!pb]; it is not
    solved here. *)
let unif : problem -> octxt -> term -> term -> unit =
 fun pb c a b ->
  let ctx = classic c in
  if Eval.pure_eq_modulo ctx a b then ()
  else begin
    if Logger.log_enabled () then
      log (Color.yel "add constraint %a") constr (ctx, a, b);
    let current = !pb in
    pb := { current with to_solve = (ctx, a, b) :: current.to_solve }
  end
(** {1 Handling coercions} *)
(** [reduce_coercions c t] simplifies the coercions occurring in [t], working
    bottom-up: subterms are simplified before their parent. It returns
    [Some t'] with [t'] the simplified term, or [None] as soon as one
    coercion cannot be eliminated. *)
let rec reduce_coercions : octxt -> term -> term option = fun c t ->
  let open Option.Monad in
  (* [under_binder b] simplifies the body of binder [b] and rebinds it. *)
  let under_binder b =
    let (x, body) = Bindlib.unbind b in
    let* body = reduce_coercions c body in
    return (Bindlib.unbox (Bindlib.bind_var x (Term.lift body)))
  in
  let (hd, args) = get_args t in
  match hd with
  | Symb s when s == Coercion.coerce ->
      (* Simplify the arguments first, then try to evaluate the coercion;
         fail if the result is still headed by a coercion. *)
      let* args = List.sequence_opt (List.map (reduce_coercions c) args) in
      let (hd', args') = get_args (Eval.whnf (classic c) (add_args hd args)) in
      (match hd' with
       | Symb s' when s' == Coercion.coerce -> None
       | _ -> reduce_coercions c (add_args hd' args'))
  | _ ->
      (* Not a coercion application: traverse the term structurally. *)
      match unfold t with
      | Patt _ | Wild | TEnv _ | TRef _ -> assert false
      | Plac _ | Kind | Type | Vari _ | Symb _ | Meta _ -> return t
      | Appl (f, a) ->
          let* f = reduce_coercions c f in
          let* a = reduce_coercions c a in
          return (mk_Appl (f, a))
      | Abst (dom, b) ->
          let* dom = reduce_coercions c dom in
          let* b = under_binder b in
          return (mk_Abst (dom, b))
      | Prod (dom, b) ->
          let* dom = reduce_coercions c dom in
          let* b = under_binder b in
          return (mk_Prod (dom, b))
      | LLet (ty, e, b) ->
          let* ty = reduce_coercions c ty in
          let* e = reduce_coercions c e in
          let* b = under_binder b in
          return (mk_LLet (ty, e, b))
(** [coerce pb c t a b] makes term [t] of type [a] usable at type [b] in
    context [c], with problem [pb]. If [a] and [b] are already equal modulo
    {!val:Eval.pure_eq_modulo}, it returns [(t, false)]. Otherwise a coercion
    from [a] to [b] is applied to [t]. If that coercion fully reduces, the
    result is type checked and returned with [true]. If it does not, the
    constraint [a ≡ b] is registered in [pb] and [(t, false)] is returned. *)
let rec coerce : problem -> octxt -> term -> term -> term -> term * bool =
  fun pb c t a b ->
  if Eval.pure_eq_modulo (classic c) a b then (t, false)
  else
    let reduced = reduce_coercions c (Coercion.apply a b t) in
    match reduced with
    | None ->
        (* No usable coercion: postpone [a ≡ b] as a constraint. *)
        unif pb c a b;
        (t, false)
    | Some u ->
        if Logger.log_enabled () then
          log "Coerced [%a : %a <: %a : %a]" term t term a term u term b;
        (* Type check the coerced term, keeping only its refinement. *)
        let (u, _, _) = infer pb c u in
        (u, true)
(** {1 Other rules} *)
(** NOTE: functions {!val:type_enforce}, {!val:force} and {!val:infer}
    return a boolean which is true iff the typechecked term has been
    modified. This allows skipping the reconstruction of some Bindlib terms
    (which calls [lift |> bind_var x |> unbox]) when nothing changed, and
    reduces the type checking time of Holide by 21%. *)
(** [type_enforce pb c a] infers the type [s] of [a] in context [c] and makes
    sure [a] is typed by a sort. It returns [(a', sort, cu)] where:
    - [sort] is [Kind] if [s] unfolds to [Kind], and [Type] in every other
      case;
    - [a'] is the refinement of [a], coerced (or constrained by unification)
      so that it has type [sort];
    - [cu] is true iff [a] has been modified. *)
and type_enforce : problem -> octxt -> term -> term * term * bool =
 fun pb c a ->
  if Logger.log_enabled () then log "Type enforce [%a]" term a;
  let (a, s, inferred_changed) = infer pb c a in
  (* Anything other than [Kind] is expected to be [Type]. The coercion below
     either adapts [a] or registers the constraint [s ≡ sort]. *)
  let sort = match unfold s with Kind -> mk_Kind | _ -> mk_Type in
  let (a, coerced_changed) = coerce pb c a s sort in
  (a, sort, inferred_changed || coerced_changed)
(** [force pb c t a] checks term [t] against type [a] in context [c]. It
    returns [(t', cu)] where [t'] is the refinement of [t] with type [a] and
    [cu] is true iff [t] has been modified. A placeholder [t] is replaced by
    a fresh metavariable. For [Plac true] that metavariable has type [Type]
    and [a ≡ Type] is required; for [Plac false] it has type [a]. *)
and force : problem -> octxt -> term -> term -> term * bool =
 fun pb c te ty ->
  if Logger.log_enabled () then log "Force [%a] of [%a]" term te term ty;
  match unfold te with
  | Plac is_type ->
      if is_type then unif pb c ty mk_Type;
      let meta_ty = if is_type then _Type else lift ty in
      (unbox (LibMeta.bmake pb (boxed c) meta_ty), true)
  | _ ->
      (* Infer the type of [te], then adapt it to the expected [ty]. *)
      let (t, inferred_ty, inferred_changed) = infer pb c te in
      let (t, coerced_changed) = coerce pb c t inferred_ty ty in
      (t, coerced_changed || inferred_changed)
(** [infer_aux pb c t] infers a type for [t] in context [c]. It returns a
    triple [(t', a, cu)] where:
    - [t'] is the refinement of [t];
    - [a] is the inferred type of [t'];
    - [cu] is true iff [t] has been modified.
    Unification constraints are registered in [pb] through {!val:unif}, and
    fresh metavariables are created with [LibMeta.bmake pb]. {!val:infer}
    wraps this function with entry/exit logging.
    @raise NotTypable if the body of a let binding occurring in [t] has type
    [Kind]. *)
and infer_aux : problem -> octxt -> term -> term * term * bool =
 fun pb c t ->
  match unfold t with
  (* These constructors are never expected in a term given to inference. *)
  | Patt _ -> assert false
  | TEnv _ -> assert false
  | Kind -> assert false
  | Wild -> assert false
  | TRef _ -> assert false
  | Type -> (mk_Type, mk_Kind, false)
  | Vari x ->
      (* A variable missing from the classic context is an invariant
         violation. *)
      let a = try Ctxt.type_of x (classic c) with Not_found -> assert false in
      (t, a, false)
  | Symb s -> (t, !(s.sym_type), false)
  | Plac true ->
      (* Fresh metavariable of type [Type]. *)
      let m = LibMeta.bmake pb (boxed c) _Type in
      (unbox m, mk_Type, true)
  | Plac false ->
      (* Fresh metavariable [m] whose type is itself a fresh metavariable
         [mt] of type [Type]. *)
      let mt = LibMeta.bmake pb (boxed c) _Type in
      let m = LibMeta.bmake pb (boxed c) mt in
      (unbox m, unbox mt, true)
  (* The type of a metavariable is a product over its environment. *)
  | (Meta (m, ts)) as t ->
      let cu = Stdlib.ref false in
      let rec ref_esubst i range =
        (* Check each term of the explicit substitution [ts] against the
           current domain, refine it in place in [ts] (hence in the returned
           term [t]), and substitute it in the codomain. *)
        if i >= Array.length ts then range else
          match unfold range with
          | Prod(ai, b) ->
              let (tsi, cuf) = force pb c ts.(i) ai in
              ts.(i) <- tsi;
              Stdlib.(cu := !cu || cuf);
              let b = Bindlib.subst b ts.(i) in
              ref_esubst (i + 1) b
          | _ ->
              (* The type of [m] must be a product with at least
                 [Array.length ts] arguments. *)
              assert false
      in
      let range = ref_esubst 0 !(m.meta_type) in
      (t, range, Stdlib.(!cu))
  | LLet (t_ty, t, u) as top ->
      (* Check that [t_ty] is typed by a sort. *)
      let t_ty, _, cu_t_ty = type_enforce pb c t_ty in
      (* Check that [t] has type [t_ty]. *)
      let t, cu_t = force pb c t t_ty in
      (* Unbind [u] and extend the context with [x : t_ty ≔ t]. *)
      let (x, u) = Bindlib.unbind u in
      let c = extend c x ~def:t t_ty in
      (* Infer the type of the body and a refinement of it. *)
      let u, u_ty, cu_u = infer pb c u in
      ( match unfold u_ty with
        | Kind ->
            Error.fatal_msg "Let bindings cannot have a body of type Kind.";
            Error.fatal_msg "Body of let binding [%a] has type Kind."
              term u;
            raise NotTypable
        | _ -> () );
      (* The type of the whole let is itself a let over the body type. *)
      let u_ty = Bindlib.(u_ty |> lift |> bind_var x |> unbox) in
      let top_ty = mk_LLet (t_ty, t, u_ty) in
      let cu = cu_t_ty || cu_t || cu_u in
      (* Rebuild the term only if some part of it changed. *)
      let top =
        if cu then
          let u = Bindlib.(u |> lift |> bind_var x |> unbox) in
          mk_LLet(t_ty, t, u)
        else top
      in
      (top, top_ty, cu)
  | Abst (dom, b) as top ->
      (* The domain is forced to have type [Type], so [type_enforce] is not
         used here. *)
      let dom, cu_dom = force pb c dom mk_Type in
      let (x, b) = Bindlib.unbind b in
      let c = extend c x dom in
      let b, range, cu_b = infer pb c b in
      let range = Bindlib.(lift range |> bind_var x |> unbox) in
      let top_ty = mk_Prod (dom, range) in
      let cu = cu_b || cu_dom in
      let top =
        if cu then
          let b = Bindlib.(lift b |> bind_var x |> unbox) in
          mk_Abst (dom, b)
        else top
      in
      (top, top_ty, cu)
  | Prod (dom, b) as top ->
      (* The domain is forced to have type [Type]; the type of the product is
         the sort of its codomain. *)
      let dom, cu_dom = force pb c dom mk_Type in
      let (x, b) = Bindlib.unbind b in
      let c = extend c x dom in
      let b, b_s, cu_b = type_enforce pb c b in
      let cu = cu_b || cu_dom in
      let top =
        if cu then
          let b = Bindlib.(lift b |> bind_var x |> unbox) in
          mk_Prod (dom, b)
        else top
      in
      (top, b_s, cu)
  | Appl (t, u) as top -> (
      let t, t_ty, cu_t = infer pb c t in
      (* [return m t u range] computes the type [range] instantiated with
         [u], and rebuilds the application only if something changed. *)
      let return m t u range =
        let ty = Bindlib.subst range u and cu = cu_t || m in
        if cu then (mk_Appl (t, u), ty, cu) else (top, ty, cu)
      in
      match Eval.whnf (classic c) t_ty with
      (* The function has a product type: check the argument against its
         domain. *)
      | Prod (dom, range) ->
          if Logger.log_enabled () then
            log "Appl-prod arg [%a]" term u;
          let u, cu_u = force pb c u dom in
          return cu_u t u range
      (* The function type is a metavariable: infer the argument type,
         create a fresh codomain and require [t_ty ≡ Π u_ty, range]. *)
      | Meta (_, _) ->
          let u, u_ty, cu_u = infer pb c u in
          let range =
            unbox (LibMeta.bmake_codomain pb (boxed c) (lift u_ty))
          in
          unif pb c t_ty (mk_Prod (u_ty, range));
          return cu_u t u range
      (* Otherwise: coerce the function to a fresh product type, then check
         the argument against the fresh domain. *)
      | t_ty ->
          let domain = LibMeta.bmake pb (boxed c) _Type in
          let range = LibMeta.bmake_codomain pb (boxed c) domain in
          let domain = unbox domain
          and range = unbox range in
          let t, cu_t' = coerce pb c t t_ty (mk_Prod (domain, range)) in
          if Logger.log_enabled () then
            log "Appl-default arg [%a]" term u;
          let u, cu_u = force pb c u domain in
          return (cu_t' || cu_u) t u range )
(** [infer pb c t] is {!val:infer_aux} surrounded by debug logging of [t]
    and of its inferred type. It returns [(t', a, cu)], where [t'] is the
    refinement of [t], [a] is its type and [cu] is true iff [t] was
    modified. *)
and infer : problem -> octxt -> term -> term * term * bool = fun pb c t ->
  if Logger.log_enabled () then log "Infer [%a]" term t;
  let ((refined, ty, _) as result) = infer_aux pb c t in
  if Logger.log_enabled () then log "Inferred [%a:@ %a]" term refined term ty;
  result
(** {b NOTE} when unbinding a binder [b] (e.g. when inferring the type of an
    abstraction [λ x, e]) in context [c], [c] is always extended, even if
    binder [b] is constant. This is because during typechecking, the context
    must contain all variables traversed to build appropriate meta-variables.
    Otherwise, the term [λ a: _, λ b: _, b] will be transformed to [λ _: ?1,
    λ b: ?2, b] whereas it should be [λ a: ?1.[], λ b: ?2.[a], b] *)
(** [noexn f pb c args] calls [f pb (c, bc) args], where [bc] is the boxed
    version of context [c] computed by [Ctxt.box_context]. It returns
    [Some r], with [r] the result of the call, or [None] if the call raises
    {!exception:NotTypable}. Any updates [f] makes to [pb] are kept. *)
let noexn :
  (problem -> octxt -> 'a -> 'b) -> problem -> ctxt -> 'a -> 'b option =
  fun f pb c args ->
  match f pb (c, Ctxt.box_context c) args with
  | result -> Some result
  | exception NotTypable -> None
(** [infer_noexn pb c t] infers the type of [t] in context [c]. It returns
    [Some (t', a)], with [t'] the refinement of [t] and [a] its type, or
    [None] if [t] is not typable. *)
let infer_noexn pb c t : (term * term) option =
  if Logger.log_enabled () then log "Top infer %a%a" ctxt c term t;
  let infer_pair pb c t =
    let (refined, ty, _) = infer pb c t in
    (refined, ty)
  in
  noexn infer_pair pb c t
(** [check_noexn pb c t a] checks [t] against type [a] in context [c]. It
    returns [Some t'], with [t'] the refinement of [t], or [None] if [t] is
    not typable. *)
let check_noexn pb c t a : term option =
  if Logger.log_enabled () then log "Top check \"%a\"" typing (c, t, a);
  let refine pb c (t, a) =
    let (refined, _) = force pb c t a in
    refined
  in
  noexn refine pb c (t, a)
(** [check_sort_noexn pb c t] checks that [t] is typed by a sort in context
    [c]. It returns [Some (t', s)], with [t'] the refinement of [t] and [s]
    its sort ([Type] or [Kind]), or [None] if [t] is not typable. *)
let check_sort_noexn pb c t : (term * term) option =
  if Logger.log_enabled () then log "Top check sort %a%a" ctxt c term t;
  let enforce pb c t =
    let (refined, sort, _) = type_enforce pb c t in
    (refined, sort)
  in
  noexn enforce pb c t