package kaun

  1. Overview
  2. Docs

Source file ptree.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
(* Short local name for the dtype module used throughout this file. *)
module Dtype = Nx_core.Dtype

(** A tensor whose element type and layout are hidden. The [P] constructor
    packs an [('a, 'layout) Rune.t]; both type parameters are existential, so
    consumers recover a typed view through [with_tensor] or a dtype witness
    ([Dtype.equal_witness]). *)
type tensor = P : ('a, 'layout) Rune.t -> tensor

(** A parameter tree. Leaves are packed tensors; interior nodes are either
    lists of subtrees or string-keyed association lists of subtrees. *)
type t = Tensor of tensor | List of t list | Dict of (string * t) list

(** [tensor x] wraps the typed tensor [x] as a leaf node. *)
let tensor x = Tensor (P x)

(** [list children] builds a list node from [children], in order. *)
let list children = List children

(** [dict kvs] builds a dictionary node, keeping the bindings in the given
    order.
    @raise Invalid_argument on the first key (scanning left to right) that
    appears a second time. *)
let dict kvs =
  let seen = Hashtbl.create (List.length kvs) in
  let check_unique (key, _) =
    match Hashtbl.find_opt seen key with
    | Some () -> invalid_arg ("Ptree.dict: duplicate key '" ^ key ^ "'")
    | None -> Hashtbl.replace seen key ()
  in
  List.iter check_unique kvs;
  Dict kvs

module Tensor = struct
  (** [dtype p] is the packed dtype of the tensor inside [p]. *)
  let dtype (P t) = Dtype.pack (Rune.dtype t)

  (** [shape p] is the shape of the tensor inside [p]. *)
  let shape (P t) = Rune.shape t

  (** [numel p] is the product of the dimensions of [p]; [1] for a rank-0
      shape. *)
  let numel packed = Array.fold_left ( * ) 1 (shape packed)

  (** [to_typed dtype p] returns the typed tensor inside [p] when its dtype
      equals [dtype], and [None] otherwise. *)
  let to_typed (type a l) (dtype : (a, l) Rune.dtype) (P t) :
      (a, l) Rune.t option =
    match Dtype.equal_witness (Rune.dtype t) dtype with
    | None -> None
    | Some Type.Equal -> Some t

  (** [to_typed_exn dtype p] is [to_typed] without the option.
      @raise Invalid_argument if the dtype of [p] differs from [dtype]. *)
  let to_typed_exn dtype packed =
    match to_typed dtype packed with
    | Some typed -> typed
    | None -> invalid_arg "Ptree.Tensor.to_typed_exn: dtype mismatch"
end

(** A callback with a polymorphic [run] field: it must accept a tensor of any
    element type and layout, which is what allows [with_tensor] to hand it
    the unpacked existential tensor. *)
type 'r tensor_handler = { run : 'a 'layout. ('a, 'layout) Rune.t -> 'r }

(** [with_tensor packed handler] unpacks [packed] and passes the typed tensor
    to [handler.run], returning its result. *)
let with_tensor packed handler = match packed with P t -> handler.run t

(** [cast_tensor_using_eq eq t] retypes [t] using a dtype equality witness
    such as the one returned by [Dtype.equal_witness]. Matching [Type.Equal]
    brings the type equations into scope; no runtime conversion happens. *)
let cast_tensor_using_eq : type a layout b layout'.
    ((a, layout) Dtype.t, (b, layout') Dtype.t) Type.eq ->
    (a, layout) Rune.t ->
    (b, layout') Rune.t =
 fun Type.Equal tensor -> tensor

(** [as_tensor tree] is [Some leaf] when [tree] is a tensor leaf, [None] for
    list and dict nodes. *)
let as_tensor = function Tensor t -> Some t | List _ | Dict _ -> None

(** [as_tensor_exn ?ctx tree] unwraps a tensor leaf. When [ctx] is non-empty
    it is appended, in parentheses, to the function name in the error
    message.
    @raise Failure if [tree] is a list or dict node. *)
let as_tensor_exn ?(ctx = "") tree =
  match as_tensor tree with
  | Some t -> t
  | None ->
      failwith
        (Printf.sprintf "Ptree.as_tensor_exn%s: expected tensor"
           (if String.equal ctx "" then "" else " (" ^ ctx ^ ")"))

(** [map f tree] applies [f] to every tensor leaf of [tree], keeping the
    list/dict structure and the dict binding order unchanged.

    Within this definition [f] is only used through [Obj.magic] and the
    recursive calls, so the type checker does not force it to be
    dtype-polymorphic: at each leaf it is coerced to
    [('a, 'layout) Rune.t -> ('a, 'layout) Rune.t] for that leaf's types.
    NOTE(review): a function that only works for one dtype would be applied
    unsoundly to leaves of other dtypes; the dtype check after the call runs
    too late to prevent that. Presumably the interface constrains [f] —
    confirm against the .mli.

    @raise Invalid_argument if [f] returns a tensor whose dtype differs from
    the dtype of its input. *)
let rec map f tree =
  match tree with
  | Tensor tensor ->
      let tensor' =
        with_tensor tensor
          {
            run =
              (fun (type a) (type layout) (t : (a, layout) Rune.t) ->
                (* Unchecked cast of [f] to this leaf's concrete types. *)
                let f' =
                  (Obj.magic f : (a, layout) Rune.t -> (a, layout) Rune.t)
                in
                let result = f' t in
                (* Runtime guard: the cast above claims the dtype is
                   preserved, so verify it before repacking. *)
                match
                  Dtype.equal_witness (Rune.dtype t) (Rune.dtype result)
                with
                | Some Type.Equal -> P result
                | None -> invalid_arg "Ptree.map: function changed dtype");
          }
      in
      Tensor tensor'
  | List items -> List (List.map (fun item -> map f item) items)
  | Dict bindings -> Dict (List.map (fun (k, v) -> (k, map f v)) bindings)

(** [map2 f lhs rhs] combines two trees of identical structure leaf by leaf.

    - Tensor leaves must share a dtype. [f] is applied to each pair and must
      return a tensor of that same dtype. As in [map], [f] is coerced with
      [Obj.magic] to each leaf's concrete type, so dtype-polymorphism of [f]
      is not checked by the type system.
    - Lists are zipped positionally.
    - Dicts are matched by key, independently of binding order. The result
      keeps the binding order of [lhs], so [flatten (map2 f a b)] lists
      leaves in the same order as [flatten a].

    @raise Invalid_argument on node-kind, list/dict length, dict key or leaf
    dtype mismatch, or if [f] changes the dtype. *)
let rec map2 f lhs rhs =
  match (lhs, rhs) with
  | Tensor lhs_tensor, Tensor rhs_tensor ->
      let tensor =
        with_tensor lhs_tensor
          {
            run =
              (fun (type a) (type layout) (t1 : (a, layout) Rune.t) ->
                with_tensor rhs_tensor
                  {
                    run =
                      (fun (type a')
                        (type layout')
                        (t2 : (a', layout') Rune.t)
                      ->
                        match
                          Dtype.equal_witness (Rune.dtype t1) (Rune.dtype t2)
                        with
                        | Some Type.Equal -> (
                            let f' =
                              (Obj.magic f
                                : (a, layout) Rune.t ->
                                  (a, layout) Rune.t ->
                                  (a, layout) Rune.t)
                            in
                            let result = f' t1 t2 in
                            match
                              Dtype.equal_witness (Rune.dtype t1)
                                (Rune.dtype result)
                            with
                            | Some Type.Equal -> P result
                            | None ->
                                invalid_arg "Ptree.map2: function changed dtype"
                            )
                        | None -> invalid_arg "Ptree.map2: dtype mismatch");
                  });
          }
      in
      Tensor tensor
  | List l_items, List r_items ->
      if List.length l_items <> List.length r_items then
        invalid_arg "Ptree.map2: list length mismatch";
      List (List.map2 (fun l r -> map2 f l r) l_items r_items)
  | Dict l_bindings, Dict r_bindings ->
      if List.length l_bindings <> List.length r_bindings then
        invalid_arg "Ptree.map2: dict length mismatch";
      let by_key (k1, _) (k2, _) = String.compare k1 k2 in
      (* Pair bindings by sorting both sides on the key, but remember each
         lhs binding's position so the result can be put back in lhs order
         afterwards (a key-sorted result would reorder the leaves). *)
      let sorted_l =
        List.stable_sort by_key
          (List.mapi (fun pos (k, v) -> (k, (pos, v))) l_bindings)
      in
      let sorted_r = List.stable_sort by_key r_bindings in
      let merged =
        List.map2
          (fun (k1, (pos, v1)) (k2, v2) ->
            if not (String.equal k1 k2) then
              invalid_arg "Ptree.map2: dict key mismatch";
            (pos, (k1, map2 f v1 v2)))
          sorted_l sorted_r
      in
      let in_lhs_order =
        List.stable_sort (fun (p1, _) (p2, _) -> Int.compare p1 p2) merged
      in
      Dict (List.map snd in_lhs_order)
  | _ -> invalid_arg "Ptree.map2: structure mismatch"

(** [map_packed f tree] replaces every packed leaf [p] with [f p], keeping
    the list/dict structure and binding order. Leaves are visited left to
    right. *)
let rec map_packed f = function
  | Tensor packed -> Tensor (f packed)
  | List children -> List (List.map (map_packed f) children)
  | Dict fields ->
      Dict (List.map (fun (key, child) -> (key, map_packed f child)) fields)

(** [iter f tree] calls [f] on every packed leaf, depth-first and left to
    right. *)
let rec iter f = function
  | Tensor packed -> f packed
  | List children -> List.iter (iter f) children
  | Dict fields -> List.iter (fun (_, child) -> iter f child) fields

(** [fold f acc tree] threads [acc] through [f] over every packed leaf,
    depth-first and left to right. *)
let rec fold f acc = function
  | Tensor packed -> f acc packed
  | List children -> List.fold_left (fold f) acc children
  | Dict fields ->
      List.fold_left (fun acc' (_, child) -> fold f acc' child) acc fields

(** [flatten tree] returns the packed leaves of [tree] in depth-first,
    left-to-right order, together with a [rebuild] function.

    [rebuild leaves] produces a tree with the same structure as [tree] whose
    leaves are taken, in order, from the front of [leaves]. Surplus leaves
    are ignored.
    @raise Failure (from [rebuild]) if [leaves] is shorter than the number of
    leaves in [tree]. *)
let flatten tree =
  let rec collect acc = function
    | Tensor t -> t :: acc
    | List items -> List.fold_left collect acc items
    | Dict bindings -> List.fold_left (fun a (_, v) -> collect a v) acc bindings
  in
  let leaves = List.rev (collect [] tree) in
  let rebuild new_leaves =
    (* Consume the supplied leaves with a cursor; indexing with [List.nth]
       made each rebuild quadratic in the number of leaves. [List.map]
       visits children left to right, matching the order of [collect]. *)
    let remaining = ref new_leaves in
    let rec aux = function
      | Tensor _ -> (
          match !remaining with
          | t :: rest ->
              remaining := rest;
              Tensor t
          | [] -> failwith "Ptree.flatten: rebuild received too few leaves")
      | List items -> List (List.map aux items)
      | Dict bindings -> Dict (List.map (fun (k, v) -> (k, aux v)) bindings)
    in
    aux tree
  in
  (leaves, rebuild)

module Path = struct
  (** A location in a tree, read from the root: [Key k] selects field [k] of
      a [Dict] node and [Index i] selects element [i] of a [List] node. *)
  type t = segment list
  and segment = Key of string | Index of int

  (** The empty path, designating the whole tree. *)
  let root = []

  (** [of_string s] parses paths such as ["encoder.layers[2].weight"]: keys
      separated by ['.'], list indices written as [[n]]. Every ['.'] is
      skipped, so leading, trailing and repeated dots are ignored. A key
      extends up to the next ['.'] or ['['].

      @raise Not_found if a ['['] has no matching [']'].
      @raise Failure if the text between brackets is not an integer. *)
  let of_string path_str =
    let len = String.length path_str in
    let rec parse i acc =
      if i >= len then List.rev acc
      else
        match path_str.[i] with
        | '.' -> parse (i + 1) acc
        | '[' ->
            let j = String.index_from path_str (i + 1) ']' in
            let idx = int_of_string (String.sub path_str (i + 1) (j - i - 1)) in
            parse (j + 1) (Index idx :: acc)
        | _ ->
            let next_dot =
              try Some (String.index_from path_str i '.')
              with Not_found -> None
            in
            let next_bracket =
              try Some (String.index_from path_str i '[')
              with Not_found -> None
            in
            let next_sep =
              match (next_dot, next_bracket) with
              | None, None -> len
              | Some idx, None -> idx
              | None, Some idx -> idx
              | Some dot_idx, Some bracket_idx -> min dot_idx bracket_idx
            in
            let key = String.sub path_str i (next_sep - i) in
            parse next_sep (Key key :: acc)
    in
    parse 0 []

  (** [to_string path] renders [path] in the syntax read by [of_string].
      Keys are not escaped, so keys containing ['.'] or ['['] do not
      round-trip. *)
  let to_string path =
    let buffer = Buffer.create 32 in
    let rec aux first = function
      | [] -> ()
      | Key k :: rest ->
          if not first then Buffer.add_char buffer '.';
          Buffer.add_string buffer k;
          aux false rest
      | Index i :: rest ->
          Buffer.add_char buffer '[';
          Buffer.add_string buffer (string_of_int i);
          Buffer.add_char buffer ']';
          aux false rest
    in
    aux true path;
    Buffer.contents buffer

  (** [key k p] extends [p] with a trailing [Key k] segment. *)
  let key k p = p @ [ Key k ]

  (** [index i p] extends [p] with a trailing [Index i] segment. *)
  let index i p = p @ [ Index i ]

  (** [get ~tree path] is the subtree at [path]. It is [None] when a key is
      missing, an index is out of range (including negative), or a segment is
      applied to the wrong kind of node. *)
  let rec get ~tree = function
    | [] -> Some tree
    | Key k :: rest -> (
        match tree with
        | Dict bindings -> (
            match List.assoc_opt k bindings with
            | Some v -> get ~tree:v rest
            | None -> None)
        | _ -> None)
    | Index i :: rest -> (
        match tree with
        | List items ->
            (* [List.nth_opt] raises on negative indices; treat them as
               absent so that [get] stays total. *)
            if i < 0 then None
            else Option.bind (List.nth_opt items i) (fun v -> get ~tree:v rest)
        | _ -> None)

  (** Empty node to create where a path must continue: [Dict []] before a key
      segment, [List []] before an index segment or at the end. *)
  let placeholder_for = function
    | [] -> List []
    | Key _ :: _ -> Dict []
    | Index _ :: _ -> List []

  (** [set ~tree path ~value] returns a copy of [tree] with [value] stored at
      [path], creating what is missing:
      - A missing dict key is appended after the existing bindings.
      - A list shorter than the index is padded with placeholder nodes (see
        [placeholder_for]).
      - A node of the wrong kind along the path is replaced by a fresh
        list or dict.
      @raise Invalid_argument if [path] contains a negative index. *)
  let rec set ~tree path ~value =
    match path with
    | [] -> value
    | Key k :: rest ->
        let bindings = match tree with Dict bs -> bs | _ -> [] in
        let rec rebuild acc = function
          | [] ->
              let child =
                if rest = [] then value
                else set ~tree:(placeholder_for rest) rest ~value
              in
              List.rev acc @ [ (k, child) ]
          | ((k', v) as binding) :: tail ->
              if String.equal k k' then
                let child =
                  if rest = [] then value else set ~tree:v rest ~value
                in
                List.rev acc @ ((k, child) :: tail)
              else rebuild (binding :: acc) tail
        in
        Dict (rebuild [] bindings)
    | Index i :: rest ->
        (* A negative index used to fall through [update] without matching
           any slot, silently dropping [value]. *)
        if i < 0 then
          invalid_arg (Printf.sprintf "Ptree.Path.set: negative index %d" i);
        let items = match tree with List xs -> xs | _ -> [] in
        let filler = placeholder_for rest in
        let len = List.length items in
        let padded =
          if i < len then items
          else
            let extra = i - len + 1 in
            items @ List.init extra (fun _ -> filler)
        in
        let rec update idx = function
          | [] -> []
          | x :: xs ->
              if idx = 0 then
                let child =
                  if rest = [] then value else set ~tree:x rest ~value
                in
                child :: xs
              else x :: update (idx - 1) xs
        in
        List (update i padded)

  (** [update ~tree path ~f] replaces the subtree at [path] with [f] applied
      to it.
      @raise Invalid_argument if nothing exists at [path]. *)
  let update ~tree path ~f =
    match get ~tree path with
    | Some subtree -> set ~tree path ~value:(f subtree)
    | None -> invalid_arg "Ptree.Path.update: path not found"
end

(** [get ~path tree] is the subtree of [tree] at [path], if any. *)
let get ~path tree = Path.get ~tree path

(** [get_exn ~path tree] is [get] without the option.
    @raise Invalid_argument if nothing exists at [path]. *)
let get_exn ~path tree =
  match get ~path tree with
  | Some t -> t
  | None ->
      invalid_arg
        (Printf.sprintf "Ptree.get_exn: path '%s' not found"
           (Path.to_string path))

(** [set ~path ~value tree] stores [value] at [path], creating intermediate
    nodes as [Path.set] does. *)
let set ~path ~value tree = Path.set ~tree path ~value

(** [update ~path f tree] replaces the subtree at [path] with [f] applied to
    it.
    @raise Invalid_argument if nothing exists at [path]. *)
let update ~path f tree = Path.update ~tree path ~f

(** [mem ~path tree] is [true] iff something exists at [path]. *)
let mem ~path tree = Option.is_some (get ~path tree)

(** [get_tensor ~path tree dtype] returns the tensor leaf at [path], typed at
    [dtype].

    Returns [None] when nothing exists at [path] or the node there is a list
    or dict.
    @raise Invalid_argument if the leaf exists but its dtype differs from
    [dtype]. *)
let get_tensor ~path tree dtype =
  match get ~path tree with
  | Some (Tensor tensor) ->
      with_tensor tensor
        {
          run =
            (fun (type a) (type layout) (t : (a, layout) Rune.t) ->
              match Dtype.equal_witness (Rune.dtype t) dtype with
              | Some eq -> Some (cast_tensor_using_eq eq t)
              | None ->
                  invalid_arg
                    (Printf.sprintf "Ptree.get_tensor: dtype mismatch at '%s'"
                       (Path.to_string path)));
        }
  | Some (List _ | Dict _) | None -> None

(** [get_tensor_exn ~path tree dtype] is [get_tensor] without the option.
    @raise Invalid_argument if there is no tensor leaf at [path], or if the
    leaf's dtype differs from [dtype]. *)
let get_tensor_exn ~path tree dtype =
  match get_tensor ~path tree dtype with
  | Some t -> t
  | None ->
      invalid_arg
        (Printf.sprintf "Ptree.get_tensor_exn: no tensor at '%s'"
           (Path.to_string path))

(** [flatten_with_paths tree] lists every packed leaf together with its path
    from the root. Leaves are in depth-first, left-to-right order, and list
    elements are numbered from 0. *)
let flatten_with_paths tree =
  let rec walk path node acc =
    match node with
    | Tensor t -> (path, t) :: acc
    | List children ->
        let _, acc =
          List.fold_left
            (fun (i, acc) child -> (i + 1, walk (Path.index i path) child acc))
            (0, acc) children
        in
        acc
    | Dict fields ->
        List.fold_left
          (fun acc (key, child) -> walk (Path.key key path) child acc)
          acc fields
  in
  List.rev (walk Path.root tree [])

(** [filter_tensors tree pred] keeps the [(path, leaf)] pairs of
    [flatten_with_paths tree] for which [pred path leaf] holds. *)
let filter_tensors tree pred =
  flatten_with_paths tree |> List.filter (fun (path, leaf) -> pred path leaf)

(** A dtype witness whose element type is [float]; the layout is hidden. *)
type float_dtype = F : (float, 'l) Rune.dtype -> float_dtype

(** [first_float_dtype tree] returns the dtype of the first leaf, in
    depth-first left-to-right order, that [Dtype.is_float] accepts and that
    is one of Float16, Float32, Float64, BFloat16, Float8_e4m3 or
    Float8_e5m2. Returns [None] if there is no such leaf. *)
let first_float_dtype tree =
  let rec search node =
    match node with
    | Tensor (P t) -> (
        let dt = Rune.dtype t in
        if not (Dtype.is_float dt) then None
        else
          match dt with
          | Dtype.Float16 -> Some (F Dtype.Float16)
          | Dtype.Float32 -> Some (F Dtype.Float32)
          | Dtype.Float64 -> Some (F Dtype.Float64)
          | Dtype.BFloat16 -> Some (F Dtype.BFloat16)
          | Dtype.Float8_e4m3 -> Some (F Dtype.Float8_e4m3)
          | Dtype.Float8_e5m2 -> Some (F Dtype.Float8_e5m2)
          | _ -> None)
    | List children -> List.find_map search children
    | Dict fields -> List.find_map (fun (_, child) -> search child) fields
  in
  search tree

(** [first_float_dtype_exn tree] is [first_float_dtype] without the option.
    @raise Invalid_argument if no leaf has a recognised float dtype. *)
let first_float_dtype_exn tree =
  match first_float_dtype tree with
  | None ->
      invalid_arg "Ptree.first_float_dtype_exn: no floating tensors in tree"
  | Some witness -> witness

(** Tree of the same structure with [Rune.zeros_like] applied to each leaf. *)
let zeros_like tree =
  map_packed (function P leaf -> P (Rune.zeros_like leaf)) tree

(** Tree of the same structure with [Rune.copy] applied to each leaf. *)
let copy tree = map_packed (function P leaf -> P (Rune.copy leaf)) tree

(** Number of tensor leaves in [tree]. *)
let count_tensors tree = fold (fun n _leaf -> n + 1) 0 tree

(** Sum of [Tensor.numel] over all leaves of [tree]. *)
let count_parameters tree =
  fold (fun total leaf -> total + Tensor.numel leaf) 0 tree

module Dict = struct
  (** The bindings of a [Dict] node, in tree order. *)
  type fields = (string * t) list

  (** [fields_exn ?ctx tree] returns the bindings of a [Dict] node.
      @raise Failure with message ["<ctx>: expected Dict node"] for any
      other node. *)
  let fields_exn ?(ctx = "Ptree.Dict.fields_exn") tree =
    match tree with
    | Dict fs -> fs
    | Tensor _ | List _ -> failwith (ctx ^ ": expected Dict node")

  (** [find name fs] is the value of the first binding for [name], if any. *)
  let find name (fs : fields) =
    List.find_map
      (fun (k, v) -> if String.equal k name then Some v else None)
      fs

  (** [find_exn ?ctx name fs] is [find] without the option.
      @raise Failure if [name] is not bound. *)
  let find_exn ?(ctx = "Ptree.Dict.find_exn") name (fs : fields) =
    match find name fs with
    | None -> failwith (ctx ^ ": missing field '" ^ name ^ "'")
    | Some v -> v

  (** [set key value fs] replaces the first binding for [key] in place, or
      appends [(key, value)] at the end if [key] is not bound. *)
  let set key value (fs : fields) : fields =
    let rec scan seen = function
      | [] -> List.rev_append seen [ (key, value) ]
      | (k, _) :: rest when String.equal k key ->
          List.rev_append seen ((key, value) :: rest)
      | binding :: rest -> scan (binding :: seen) rest
    in
    scan [] fs

  (** [update f key fs] replaces the value bound to [key] with [f] applied
      to it.
      @raise Failure if [key] is not bound. *)
  let update f key (fs : fields) =
    match find key fs with
    | None -> failwith ("Ptree.Dict.update: missing field '" ^ key ^ "'")
    | Some v -> set key (f v) fs

  (** [mem key fs] is [true] iff [key] is bound in [fs]. *)
  let mem key (fs : fields) = Option.is_some (find key fs)

  (** Typed lookup of field [name]; see the top-level [get_tensor]. *)
  let get_tensor (fs : fields) ~name dtype =
    let path = Path.key name Path.root in
    get_tensor (Dict fs) ~path dtype

  (** Typed lookup of field [name]; see the top-level [get_tensor_exn]. *)
  let get_tensor_exn (fs : fields) ~name dtype =
    let path = Path.key name Path.root in
    get_tensor_exn (Dict fs) ~path dtype
end

module List = struct
  (** [items_exn ?ctx tree] returns the children of a [List] node. [ctx]
      prefixes the error message; it defaults to this function's qualified
      name, following the convention of [Dict.fields_exn].
      @raise Failure if [tree] is not a [List] node. *)
  let items_exn ?(ctx = "Ptree.List.items_exn") = function
    | List xs -> xs
    | Tensor _ | Dict _ -> failwith (ctx ^ ": expected List node")
end

(** [pp fmt tree] pretty-prints [tree]:
    - Leaves are shown as [Tensor(<dims>:<dtype>)], with the dimensions
      joined by ["×"], or ["scalar"] for an empty shape.
    - Lists are shown as [[a, b, ...]].
    - Dicts are shown as [{k = v, ...}].
    Each list or dict body is placed in a [hov] box. *)
let rec pp fmt tree =
  let sep fmt () = Format.fprintf fmt ",@ " in
  match tree with
  | Tensor (P t) ->
      let dims = Array.to_list (Array.map string_of_int (Rune.shape t)) in
      let shape_str =
        match dims with [] -> "scalar" | _ :: _ -> String.concat "×" dims
      in
      Format.fprintf fmt "Tensor(%s:%s)" shape_str
        (Dtype.to_string (Rune.dtype t))
  | List items ->
      Format.fprintf fmt "[@[<hov>%a@]]"
        (Format.pp_print_list ~pp_sep:sep pp)
        items
  | Dict bindings ->
      let pp_binding fmt (k, v) = Format.fprintf fmt "%s = %a" k pp v in
      Format.fprintf fmt "{@[<hov>%a@]}"
        (Format.pp_print_list ~pp_sep:sep pp_binding)
        bindings

(** [to_string tree] renders [tree] with [pp]. *)
let to_string tree = Format.asprintf "%a" pp tree