package kaun

  1. Overview
  2. Docs
Flax-inspired neural network library for OCaml

Install

dune-project
 Dependency

Authors

Maintainers

Sources

raven-1.0.0.alpha2.tbz
sha256=93abc49d075a1754442ccf495645bc4fdc83e4c66391ec8aca8fa15d2b4f44d2
sha512=5eb958c51f30ae46abded4c96f48d1825f79c7ce03f975f9a6237cdfed0d62c0b4a0774296694def391573d849d1f869919c49008acffca95946b818ad325f6f

doc/src/kaun.huggingface/kaun_huggingface.ml.html

Source file kaun_huggingface.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
open Rune

(** Implementation of HuggingFace model hub integration for Kaun *)

(* Types *)

(** Identifier of a model repository. It is interpolated verbatim into hub
    URLs and, with ['/'] replaced by ['-'], used as a cache directory name. *)
type model_id = string

(** Which revision of a repository to fetch. [Latest] is resolved to the
    ["main"] ref by the URL and cache-path helpers in this file; [Tag] and
    [Commit] carry the ref name to use. *)
type revision = Latest | Tag of string | Commit of string

(** Path of the directory under which downloaded files are stored. *)
type cache_dir = string

(** Statistics describing a completed download. *)
type download_progress = {
  downloaded_bytes : int;  (** Number of bytes written to disk. *)
  total_bytes : int option;
      (** Expected total size when known, [None] otherwise. *)
  rate : float;
      (** Average throughput; computed in this file as bytes divided by
          elapsed wall-clock seconds. *)
}

(** Outcome of a fetch: [Cached v] when [v] was served from the local cache,
    [Downloaded (v, progress)] when a network transfer took place. *)
type 'a download_result = Cached of 'a | Downloaded of 'a * download_progress

(* Configuration *)

(** Settings that control where files are cached and how downloads behave. *)
module Config = struct
  type t = {
    cache_dir : cache_dir;  (** Root directory of the on-disk cache. *)
    token : string option;
        (** Optional access token. NOTE(review): the download code visible
            in this file never reads this field - confirm whether it is
            consumed elsewhere. *)
    offline_mode : bool;
        (** When [true], files missing from the cache are an error instead
            of being downloaded. *)
    force_download : bool;
        (** When [true], cached files are ignored and fetched again. *)
    show_progress : bool;
        (** When [true], downloads print a message and curl output. *)
  }

  (** Built-in defaults: cache under [$HOME/.cache/kaun/huggingface] (falling
      back to [/tmp] when [HOME] is unset), no token, online, no forced
      re-download, progress shown. *)
  let default =
    let home = Option.value (Sys.getenv_opt "HOME") ~default:"/tmp" in
    {
      cache_dir = Filename.concat home ".cache/kaun/huggingface";
      token = None;
      offline_mode = false;
      force_download = false;
      show_progress = true;
    }

  (** [from_env ()] builds a configuration from the [KAUN_HF_*] environment
      variables, using {!default} for the cache directory when
      [KAUN_HF_CACHE_DIR] is unset. Boolean switches are on only for the
      exact values ["true"] and ["1"]; progress display is off only for the
      exact values ["false"] and ["0"]. *)
  let from_env () =
    let lookup = Sys.getenv_opt in
    (* Opt-in flags: anything other than the two accepted spellings is off. *)
    let switched_on var =
      match lookup var with
      | Some ("true" | "1") -> true
      | Some _ | None -> false
    in
    (* Opt-out flag: anything other than the two accepted spellings is on. *)
    let switched_off var =
      match lookup var with
      | Some ("false" | "0") -> true
      | Some _ | None -> false
    in
    {
      cache_dir =
        Option.value (lookup "KAUN_HF_CACHE_DIR") ~default:default.cache_dir;
      token = lookup "KAUN_HF_TOKEN";
      offline_mode = switched_on "KAUN_HF_OFFLINE_MODE";
      force_download = switched_on "KAUN_HF_FORCE_DOWNLOAD";
      show_progress = not (switched_off "KAUN_HF_SHOW_PROGRESS");
    }
end

(* Model Registry *)

(** Global, name-indexed table of model specifications. *)
module Registry = struct
  (** Description of a loadable model: the names of its configuration and
      weight files, a decoder from the JSON configuration to ['params], and
      a builder producing Kaun parameters for a given dtype. *)
  type ('params, 'a, 'dev) model_spec = {
    architecture : string;
    config_file : string;
    weight_files : string list;
    load_config : Yojson.Safe.t -> 'params;
    build_params : dtype:(float, 'a) dtype -> 'params -> Kaun.params;
  }

  (* Entries are stored type-erased so that specs with different type
     parameters can share a single table. *)
  let registry : (string, Obj.t) Hashtbl.t = Hashtbl.create 10

  (** [register name spec] stores [spec] under [name], replacing any
      previous entry with the same name. *)
  let register name spec =
    let erased = Obj.repr spec in
    Hashtbl.replace registry name erased

  (** [get name] returns the entry stored under [name], or [None] when there
      is none. The result is cast back with [Obj.obj] to whatever type the
      call site expects; nothing checks that it matches the type used at
      registration, so a mismatch is undefined behaviour. *)
  let get name = Option.map Obj.obj (Hashtbl.find_opt registry name)
end

(* Utilities *)

(** [ensure_dir dir] creates [dir] and any missing ancestors with mode
    [0o755]. Paths that already exist are left untouched, and a concurrent
    creation reported as [EEXIST] is ignored.
    @raise Unix.Unix_error for any other [mkdir] failure. *)
let ensure_dir dir =
  let rec create path =
    if Sys.file_exists path then ()
    else begin
      (* Parents first, so the mkdir below only ever adds one level. *)
      create (Filename.dirname path);
      match Unix.mkdir path 0o755 with
      | () -> ()
      | exception Unix.Unix_error (Unix.EEXIST, _, _) -> ()
    end
  in
  create dir

(** [hub_url ~model_id ~filename ~revision] is the download URL of
    [filename] in repository [model_id] at [revision]. [Latest] maps to the
    ["main"] ref; tags and commits are used as given. No URL escaping is
    applied to any component. *)
let hub_url ~model_id ~filename ~revision =
  let ref_name =
    match revision with Latest -> "main" | Tag name | Commit name -> name
  in
  Printf.sprintf "https://huggingface.co/%s/resolve/%s/%s" model_id ref_name
    filename

(** [cache_path config ~model_id ~filename ~revision] is the location of
    [filename] inside the local cache:
    [<cache_dir>/<model>/<revision>/<filename>], where [<model>] is
    [model_id] with every ['/'] replaced by ['-'] and [<revision>] is
    [main], [tags/<tag>] or [commits/<commit>]. *)
let cache_path config ~model_id ~filename ~revision =
  let revision_dir =
    match revision with
    | Latest -> "main"
    | Tag name -> "tags/" ^ name
    | Commit hash -> "commits/" ^ hash
  in
  (* Flatten the repository id into a single directory name. *)
  let flatten = function '/' -> '-' | c -> c in
  let model_dir = String.map flatten model_id in
  (* Right fold keeps the nesting cache_dir / (model / (revision / file)). *)
  List.fold_right Filename.concat
    [ config.Config.cache_dir; model_dir; revision_dir ]
    filename

(* Core Loading Functions *)

(** [download_with_progress ~url ~dest ~show_progress] fetches [url] into
    [dest] by running [curl], creating the parent directory of [dest] first.

    The transfer is written to [dest ^ ".part"] and renamed onto [dest] only
    after curl reports success, so a failed or interrupted download never
    leaves a truncated file at [dest]. curl is run with [--fail] so that
    HTTP error responses (for example 404) count as failures instead of
    having their body stored as if it were the requested file. Both the
    output path and the URL are shell-quoted.

    When [show_progress] is [true] a message is printed and curl displays
    its own progress output; otherwise curl runs with [-s].

    @return the byte count of the downloaded file and the average rate in
    bytes per second ([0.] when no measurable time elapsed).
    @raise Failure with a message starting with ["Failed to download"] when
    curl exits unsuccessfully or produces no file. *)
let download_with_progress ~url ~dest ~show_progress =
  ensure_dir (Filename.dirname dest);

  let partial = dest ^ ".part" in
  let discard_partial () = try Sys.remove partial with Sys_error _ -> () in

  (* Use curl with progress bar if requested *)
  let progress_flag = if show_progress then "" else "-s" in
  let cmd =
    Printf.sprintf "curl --fail -L %s -o %s %s" progress_flag
      (Filename.quote partial) (Filename.quote url)
  in

  if show_progress then Printf.printf "Downloading from %s...\n%!" url;

  let start_time = Unix.gettimeofday () in

  match Unix.system cmd with
  | Unix.WEXITED 0 when Sys.file_exists partial ->
      let elapsed = Unix.gettimeofday () -. start_time in
      (* Publish the finished file under its final name. *)
      Sys.rename partial dest;
      let bytes = (Unix.stat dest).Unix.st_size in
      (* Avoid producing infinity or nan on a zero-length interval. *)
      let rate = if elapsed > 0. then float_of_int bytes /. elapsed else 0. in
      { downloaded_bytes = bytes; total_bytes = Some bytes; rate }
  | _ ->
      discard_partial ();
      failwith (Printf.sprintf "Failed to download %s" url)

(** [download_file ?config ?revision ~model_id ~filename ()] makes
    [filename] from repository [model_id] available locally and returns its
    cache path.

    - [Cached path] when the file is already in the cache and
      [config.force_download] is off.
    - [Downloaded (path, progress)] after fetching it from the hub.

    @raise Failure when the file must be fetched but [config.offline_mode]
    is on, or when the download itself fails. *)
let download_file ?(config = Config.default) ?(revision = Latest) ~model_id
    ~filename () =
  let local_path = cache_path config ~model_id ~filename ~revision in
  let usable_cache =
    Sys.file_exists local_path && not config.force_download
  in
  match (usable_cache, config.offline_mode) with
  | true, _ -> Cached local_path
  | false, true ->
      failwith
        (Printf.sprintf "File not in cache (offline mode): %s" local_path)
  | false, false ->
      let url = hub_url ~model_id ~filename ~revision in
      let progress =
        download_with_progress ~url ~dest:local_path
          ~show_progress:config.show_progress
      in
      Downloaded (local_path, progress)

(** [load_safetensors ?config ?revision ~model_id ()] loads the weights of
    [model_id] into a flat [Kaun.Ptree] dictionary keyed by tensor name.

    Strategy:
    + Try each shard index file in [index_filenames]. When one is obtained
      and parsed, every shard it references is fetched and only the tensors
      listed for that shard are loaded.
    + If no index could be used, try each single-file name in
      [fallback_filenames] and load every tensor it contains.

    The result is [Cached params] when nothing had to be downloaded, and
    [Downloaded (params, progress)] otherwise; for sharded models [progress]
    aggregates the statistics of all transfers.

    @raise Failure when no candidate file could be loaded, or when an index
    file is obtained but is malformed, incomplete or inconsistent with its
    shards (those failures are not turned into a fallback). *)
let load_safetensors ?(config = Config.default) ?(revision = Latest) ~model_id
    () =
  (* Read the archive at [path] and turn its tensors into named Ptree
     leaves. Without [allowed_names] every tensor is returned, in the
     iteration order of the underlying table; with it, exactly the listed
     tensors are returned in list order and a missing one is an error. *)
  let load_entries ?allowed_names path =
    let archive = Nx_io.load_safetensor path in
    match allowed_names with
    | None ->
        Hashtbl.fold
          (fun name (Nx_io.P nx_tensor) acc ->
            let rune_tensor = Rune.of_nx nx_tensor in
            (name, Kaun.Ptree.tensor rune_tensor) :: acc)
          archive []
    | Some names ->
        List.map
          (fun name ->
            match Hashtbl.find_opt archive name with
            | Some (Nx_io.P nx_tensor) ->
                let rune_tensor = Rune.of_nx nx_tensor in
                (name, Kaun.Ptree.tensor rune_tensor)
            | None ->
                failwith
                  (Printf.sprintf
                     "Shard for %s missing tensor '%s' while loading %s"
                     model_id name path))
          names
  in

  let params_from_entries entries = Kaun.Ptree.dict entries in

  (* Wrap [params] in a download result. An empty progress list means every
     file came from the cache; otherwise the individual statistics are
     merged into one record. *)
  let apply_progress progress_list params =
    let combine acc next =
      (* Transfer durations are not stored, so they are reconstructed as
         bytes / rate; a non-positive rate is treated as zero duration. *)
      let accumulated_time =
        if acc.rate <= 0. then 0.
        else float_of_int acc.downloaded_bytes /. acc.rate
      in
      let next_time =
        if next.rate <= 0. then 0.
        else float_of_int next.downloaded_bytes /. next.rate
      in
      let total_bytes = acc.downloaded_bytes + next.downloaded_bytes in
      let total_time = accumulated_time +. next_time in
      let rate =
        if total_time <= 0. then acc.rate
        else float_of_int total_bytes /. total_time
      in
      (* The combined total is known only if both parts know theirs. *)
      let total_size =
        match (acc.total_bytes, next.total_bytes) with
        | Some a, Some b -> Some (a + b)
        | _ -> None
      in
      { downloaded_bytes = total_bytes; total_bytes = total_size; rate }
    in
    match progress_list with
    | [] -> Cached params
    | first :: rest ->
        let total = List.fold_left combine first rest in
        Downloaded (params, total)
  in

  (* Candidate shard index files, tried in this order.
     NOTE(review): shards referenced by the second name are presumably not
     in safetensors format, yet they are passed to the safetensors loader -
     confirm this entry is intended. *)
  let index_filenames =
    [ "model.safetensors.index.json"; "pytorch_model.bin.index.json" ]
  in

  (* Candidate single-file checkpoints, tried in this order when no index
     could be used. *)
  let fallback_filenames =
    [
      "model.safetensors";
      "pytorch_model.safetensors";
      "model-00001-of-00001.safetensors";
    ]
  in

  let local_path_of_result = function
    | Cached path -> path
    | Downloaded (path, _) -> path
  in

  (* Prepend the statistics of [result] to [acc] if a transfer happened. *)
  let progress_of_result acc = function
    | Cached _ -> acc
    | Downloaded (_, progress) -> progress :: acc
  in

  (* Try to load a sharded checkpoint described by index file [filename].
     Returns [None] when the index is unavailable or unreadable (see the
     handlers at the end) and [Some result] on success. Structural problems
     in an index that was obtained raise [Failure] and are not caught. *)
  let attempt_index filename =
    try
      let result = download_file ~config ~revision ~model_id ~filename () in
      let index_path = local_path_of_result result in
      let progress_acc = progress_of_result [] result in

      (* The index is a JSON object whose [weight_map] member maps each
         tensor name to the shard file that contains it. *)
      let json = Yojson.Safe.from_file index_path in
      let weight_map =
        match Yojson.Safe.Util.member "weight_map" json with
        | `Assoc entries ->
            List.map
              (fun (tensor_name, shard_json) ->
                match shard_json with
                | `String shard -> (tensor_name, shard)
                | _ -> failwith "Invalid shard entry in weight_map")
              entries
        | _ -> failwith "Missing weight_map in index file"
      in

      if weight_map = [] then failwith "Empty weight_map in index file";

      (* Group tensor names by shard file. Names are consed, so each list is
         in reverse index order; [file_order] records first-seen order of
         the shard files, also reversed until the [List.rev] below. *)
      let shards_by_file = Hashtbl.create 8 in
      let file_order = ref [] in
      List.iter
        (fun (tensor_name, shard_filename) ->
          let existing = Hashtbl.find_opt shards_by_file shard_filename in
          match existing with
          | Some tensors ->
              Hashtbl.replace shards_by_file shard_filename
                (tensor_name :: tensors)
          | None ->
              Hashtbl.add shards_by_file shard_filename [ tensor_name ];
              file_order := shard_filename :: !file_order)
        weight_map;

      let file_order = List.rev !file_order in
      (* Used to detect a tensor name that appears more than once. *)
      let seen_tensors = Hashtbl.create (List.length weight_map) in

      let progress_list, entries_rev =
        List.fold_left
          (fun (progresses, acc_entries_rev) shard_filename ->
            let shard_result =
              download_file ~config ~revision ~model_id ~filename:shard_filename
                ()
            in
            let shard_path = local_path_of_result shard_result in
            let progresses = progress_of_result progresses shard_result in
            let tensors =
              match Hashtbl.find_opt shards_by_file shard_filename with
              | Some names -> List.rev names
              | None ->
                  failwith
                    (Printf.sprintf "Shard mapping missing for file '%s' in %s"
                       shard_filename filename)
            in
            let new_entries = load_entries ~allowed_names:tensors shard_path in
            List.iter
              (fun (tensor_name, _) ->
                if Hashtbl.mem seen_tensors tensor_name then
                  failwith
                    (Printf.sprintf
                       "Tensor '%s' defined multiple times across shards"
                       tensor_name);
                Hashtbl.add seen_tensors tensor_name ())
              new_entries;
            (* NOTE(review): reversing and then consing prepends
               [new_entries] in its original order, so after the final
               [List.rev] the entries of each shard come out reversed and
               later shards come first. Harmless if the dictionary does not
               depend on entry order - confirm. *)
            let acc_entries_rev =
              List.fold_left
                (fun acc entry -> entry :: acc)
                acc_entries_rev (List.rev new_entries)
            in
            (progresses, acc_entries_rev))
          (progress_acc, []) file_order
      in

      (* Every index entry must have produced exactly one loaded tensor. *)
      if Hashtbl.length seen_tensors <> List.length weight_map then
        failwith
          "Incomplete shard loading: not all tensors listed in the weight_map \
           were found";

      let entries = List.rev entries_rev in
      let params = params_from_entries entries in
      Some (apply_progress progress_list params)
    with
    (* Only failures meaning that the index is not available are converted
       into [None]; they are recognised by message prefix, so these strings
       must stay in sync with the functions that raise them. *)
    | Failure msg
      when String.starts_with ~prefix:"Failed to download" msg
           || String.starts_with ~prefix:"No such file" msg
           || String.starts_with ~prefix:"File not in cache (offline mode)" msg
      ->
        None
    | Yojson.Json_error _ -> None
    | Sys_error _ -> None
  in

  (* First index file that yields a result, if any. *)
  let rec try_indexes = function
    | [] -> None
    | filename :: rest -> (
        match attempt_index filename with
        | Some result -> Some result
        | None -> try_indexes rest)
  in

  match try_indexes index_filenames with
  | Some result -> result
  | None ->
      (* No usable index: fall back to single-file checkpoints. *)
      let rec try_files = function
        | [] ->
            failwith
              (Printf.sprintf "No safetensors file found for %s" model_id)
        | filename :: rest -> (
            try
              let result =
                download_file ~config ~revision ~model_id ~filename ()
              in
              let local_path = local_path_of_result result in
              let entries = load_entries local_path in
              let params = params_from_entries entries in
              match result with
              | Cached _ -> Cached params
              | Downloaded (_, progress) -> Downloaded (params, progress)
            with
            (* NOTE(review): all three handlers do the same thing and the
               last one subsumes the others, so any exception raised while
               fetching or parsing a candidate moves on to the next file.
               The original cause is lost by the time the final failure
               above is raised. *)
            | Failure msg
              when String.starts_with ~prefix:"Failed to download" msg
                   || String.starts_with ~prefix:"No such file" msg ->
                try_files rest
            | Failure _ -> try_files rest
            | _ -> try_files rest)
      in
      try_files fallback_filenames

(** [load_config ?config ?revision ~model_id ()] obtains [config.json] for
    [model_id] (from the cache or the hub) and parses it as JSON. The result
    keeps the [Cached] / [Downloaded] status of the underlying file fetch.
    @raise Failure when the file cannot be obtained.
    @raise Yojson.Json_error when the file is not valid JSON. *)
let load_config ?(config = Config.default) ?(revision = Latest) ~model_id () =
  let parse path = Yojson.Safe.from_file path in
  match
    download_file ~config ~revision ~model_id ~filename:"config.json" ()
  with
  | Cached path -> Cached (parse path)
  | Downloaded (path, progress) ->
      let json = parse path in
      Downloaded (json, progress)

(* High-level Model Loading *)

(** [from_pretrained ?config ?revision ~model_id ()] returns the parameters
    loaded from the safetensors weights of [model_id], discarding the
    information about whether they came from the cache or a download. *)
let from_pretrained ?(config = Config.default) ?(revision = Latest) ~model_id ()
    =
  let outcome = load_safetensors ~config ~revision ~model_id () in
  match outcome with Cached params | Downloaded (params, _) -> params

(* Utilities *)

(** [list_cached_models ?config ()] lists the sub-directories of the cache
    directory, each with every ['-'] turned back into ['/']. Returns [[]]
    when the cache directory does not exist.

    NOTE(review): the cache layout maps ['/'] to ['-'], which is not
    reversible; a model id that itself contains ['-'] is reported here with
    extra slashes. *)
let list_cached_models ?(config = Config.default) () =
  let root = config.cache_dir in
  match Sys.file_exists root with
  | false -> []
  | true ->
      let to_model_id = String.map (function '-' -> '/' | c -> c) in
      Sys.readdir root |> Array.to_list
      |> List.filter_map (fun entry ->
             if Sys.is_directory (Filename.concat root entry) then
               Some (to_model_id entry)
             else None)

(** [clear_cache ?config ?model_id ()] deletes cached files. With
    [~model_id] only the directory of that model is removed; without it the
    whole cache directory is removed. Nothing happens when the target does
    not exist.

    Symbolic links found inside the cache are removed as links and never
    followed, so data outside the cache is not touched.

    @raise Unix.Unix_error or [Sys_error] when an entry cannot be inspected
    or removed. *)
let clear_cache ?(config = Config.default) ?model_id () =
  let rec rm_rf path =
    (* lstat, unlike Sys.is_directory, does not follow symlinks: a link to a
       directory is unlinked instead of having its target emptied. *)
    match (Unix.lstat path).Unix.st_kind with
    | Unix.S_DIR ->
        let entries = Sys.readdir path in
        Array.iter (fun entry -> rm_rf (Filename.concat path entry)) entries;
        Unix.rmdir path
    | _ -> Sys.remove path
  in

  match model_id with
  | Some id ->
      (* Same flattening of the model id as used when building cache paths. *)
      let model_dir = String.map (fun c -> if c = '/' then '-' else c) id in
      let path = Filename.concat config.cache_dir model_dir in
      if Sys.file_exists path then rm_rf path
  | None -> if Sys.file_exists config.cache_dir then rm_rf config.cache_dir

(** [get_model_info model_id] queries the hub API for [model_id] by running
    [curl] and parses the response as JSON.

    Returns [Ok json] when curl exits with status 0 and its output parses,
    [Error "Failed to parse JSON response"] when the output is not valid
    JSON, and [Error "Failed to fetch model info"] when curl exits with any
    other status.

    NOTE(review): curl is run without [--fail], so an HTTP error whose body
    is JSON is still returned as [Ok] - confirm whether callers rely on
    that. *)
let get_model_info model_id =
  let url = Printf.sprintf "https://huggingface.co/api/models/%s" model_id in
  (* Quote the URL so that characters in [model_id] cannot break out of the
     shell command. *)
  let cmd = Printf.sprintf "curl -s %s" (Filename.quote url) in

  let ic = Unix.open_process_in cmd in
  (* Accumulate in a buffer: linear time and constant stack depth,
     regardless of how many lines the response has. *)
  let collected = Buffer.create 4096 in
  (try
     while true do
       Buffer.add_string collected (input_line ic);
       Buffer.add_char collected '\n'
     done
   with End_of_file -> ());
  let output = Buffer.contents collected in
  let status = Unix.close_process_in ic in

  match status with
  | Unix.WEXITED 0 -> (
      (* Any parser exception is deliberately reported as a parse error. *)
      try Ok (Yojson.Safe.from_string output)
      with _ -> Error "Failed to parse JSON response")
  | _ -> Error "Failed to fetch model info"