package kaun

  1. Overview
  2. Docs
Flax-inspired neural network library for OCaml

Install

dune-project
 Dependency

Authors

Maintainers

Sources

raven-1.0.0.alpha1.tbz
sha256=8e277ed56615d388bc69c4333e43d1acd112b5f2d5d352e2453aef223ff59867
sha512=369eda6df6b84b08f92c8957954d107058fb8d3d8374082e074b56f3a139351b3ae6e3a99f2d4a4a2930dd950fd609593467e502368a13ad6217b571382da28c

doc/src/kaun.huggingface/kaun_huggingface.ml.html

Source file kaun_huggingface.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
open Rune

(** Implementation of HuggingFace model hub integration for Kaun *)

(* Types *)

type model_id = string
type revision = Latest | Tag of string | Commit of string
type cache_dir = string

type download_progress = {
  downloaded_bytes : int;
  total_bytes : int option;
  rate : float;
}

type 'a download_result = Cached of 'a | Downloaded of 'a * download_progress

(* Configuration *)

module Config = struct
  type t = {
    cache_dir : cache_dir;
    token : string option;
    offline_mode : bool;
    force_download : bool;
    show_progress : bool;
  }

  let default =
    {
      cache_dir =
        Filename.concat
          (try Sys.getenv "HOME" with Not_found -> "/tmp")
          ".cache/kaun/huggingface";
      token = None;
      offline_mode = false;
      force_download = false;
      show_progress = true;
    }

  let from_env () =
    let get_env_opt var = try Some (Sys.getenv var) with Not_found -> None in
    let get_env_bool var =
      match get_env_opt var with Some "true" | Some "1" -> true | _ -> false
    in
    {
      cache_dir =
        (match get_env_opt "KAUN_HF_CACHE_DIR" with
        | Some dir -> dir
        | None -> default.cache_dir);
      token = get_env_opt "KAUN_HF_TOKEN";
      offline_mode = get_env_bool "KAUN_HF_OFFLINE_MODE";
      force_download = get_env_bool "KAUN_HF_FORCE_DOWNLOAD";
      show_progress =
        (match get_env_opt "KAUN_HF_SHOW_PROGRESS" with
        | Some "false" | Some "0" -> false
        | _ -> true);
    }
end

(* Model Registry *)

module Registry = struct
  type ('params, 'a, 'dev) model_spec = {
    architecture : string;
    config_file : string;
    weight_files : string list;
    load_config : Yojson.Safe.t -> 'params;
    build_params : dtype:(float, 'a) dtype -> 'params -> 'a Kaun.params;
  }

  let registry : (string, Obj.t) Hashtbl.t = Hashtbl.create 10
  let register name spec = Hashtbl.replace registry name (Obj.repr spec)

  let get name =
    try Some (Obj.obj (Hashtbl.find registry name)) with Not_found -> None
end

(* Utilities *)

let ensure_dir dir =
  let rec mkdir_p path =
    if not (Sys.file_exists path) then (
      mkdir_p (Filename.dirname path);
      try Unix.mkdir path 0o755 with Unix.Unix_error (Unix.EEXIST, _, _) -> ())
  in
  mkdir_p dir

let hub_url ~model_id ~filename ~revision =
  let revision_str =
    match revision with
    | Latest -> "main"
    | Tag tag -> tag
    | Commit commit -> commit
  in
  Printf.sprintf "https://huggingface.co/%s/resolve/%s/%s" model_id revision_str
    filename

let cache_path config ~model_id ~filename ~revision =
  let revision_str =
    match revision with
    | Latest -> "main"
    | Tag tag -> "tags/" ^ tag
    | Commit commit -> "commits/" ^ commit
  in
  let model_dir = String.map (fun c -> if c = '/' then '-' else c) model_id in
  Filename.concat config.Config.cache_dir
    (Filename.concat model_dir (Filename.concat revision_str filename))

(* Core Loading Functions *)

let download_with_progress ~url ~dest ~show_progress =
  ensure_dir (Filename.dirname dest);

  (* Use curl with progress bar if requested *)
  let progress_flag = if show_progress then "" else "-s" in
  let cmd = Printf.sprintf "curl -L %s -o %s '%s'" progress_flag dest url in

  if show_progress then Printf.printf "Downloading from %s...\n%!" url;

  let start_time = Unix.gettimeofday () in

  match Unix.system cmd with
  | Unix.WEXITED 0 ->
      let elapsed = Unix.gettimeofday () -. start_time in
      let stats = Unix.stat dest in
      let bytes = stats.Unix.st_size in
      let rate = float_of_int bytes /. elapsed in
      { downloaded_bytes = bytes; total_bytes = Some bytes; rate }
  | _ -> failwith (Printf.sprintf "Failed to download %s" url)

let download_file ?(config = Config.default) ?(revision = Latest) ~model_id
    ~filename () =
  let local_path = cache_path config ~model_id ~filename ~revision in

  (* Check if already cached *)
  if Sys.file_exists local_path && not config.force_download then
    Cached local_path
  else if config.offline_mode then
    failwith (Printf.sprintf "File not in cache (offline mode): %s" local_path)
  else
    (* Download the file *)
    let url = hub_url ~model_id ~filename ~revision in
    let progress =
      download_with_progress ~url ~dest:local_path
        ~show_progress:config.show_progress
    in
    Downloaded (local_path, progress)

let load_safetensors ?(config = Config.default) ?(revision = Latest) ~model_id
    ~dtype () =
  (* Try common safetensors filenames *)
  let filenames =
    [
      "model.safetensors";
      "pytorch_model.safetensors";
      "model-00001-of-00001.safetensors";
    ]
  in

  let rec try_files = function
    | [] ->
        failwith (Printf.sprintf "No safetensors file found for %s" model_id)
    | filename :: rest -> (
        try
          let result = download_file ~config ~revision ~model_id ~filename () in
          let local_path =
            match result with
            | Cached path -> path
            | Downloaded (path, _) -> path
          in

          (* Load using Kaun_checkpoint *)
          let checkpointer = Kaun.Checkpoint.Checkpointer.create () in
          let params =
            Kaun.Checkpoint.Checkpointer.restore_file checkpointer
              ~path:local_path ~dtype
          in

          match result with
          | Cached _ -> Cached params
          | Downloaded (_, progress) -> Downloaded (params, progress)
        with _ -> try_files rest)
  in

  try_files filenames

let load_config ?(config = Config.default) ?(revision = Latest) ~model_id () =
  let result =
    download_file ~config ~revision ~model_id ~filename:"config.json" ()
  in
  let local_path =
    match result with Cached path -> path | Downloaded (path, _) -> path
  in

  let json = Yojson.Safe.from_file local_path in

  match result with
  | Cached _ -> Cached json
  | Downloaded (_, progress) -> Downloaded (json, progress)

(* High-level Model Loading *)

let from_pretrained ?(config = Config.default) ?(revision = Latest) ~model_id
    ~dtype () =
  (* Load safetensors weights *)
  match load_safetensors ~config ~revision ~model_id ~dtype () with
  | Cached params -> params
  | Downloaded (params, _) -> params

(* Utilities *)

let list_cached_models ?(config = Config.default) () =
  if not (Sys.file_exists config.cache_dir) then []
  else
    let entries = Sys.readdir config.cache_dir in
    Array.to_list entries
    |> List.filter (fun e ->
           Sys.is_directory (Filename.concat config.cache_dir e))
    |> List.map (fun e -> String.map (fun c -> if c = '-' then '/' else c) e)

let clear_cache ?(config = Config.default) ?model_id () =
  let rec rm_rf path =
    if Sys.is_directory path then (
      let entries = Sys.readdir path in
      Array.iter (fun entry -> rm_rf (Filename.concat path entry)) entries;
      Unix.rmdir path)
    else Sys.remove path
  in

  match model_id with
  | Some id ->
      let model_dir = String.map (fun c -> if c = '/' then '-' else c) id in
      let path = Filename.concat config.cache_dir model_dir in
      if Sys.file_exists path then rm_rf path
  | None -> if Sys.file_exists config.cache_dir then rm_rf config.cache_dir

let get_model_info model_id =
  let url = Printf.sprintf "https://huggingface.co/api/models/%s" model_id in
  let cmd = Printf.sprintf "curl -s '%s'" url in

  let ic = Unix.open_process_in cmd in
  let rec read_all acc =
    try
      let line = input_line ic in
      read_all (acc ^ line ^ "\n")
    with End_of_file -> acc
  in
  let output = read_all "" in
  let status = Unix.close_process_in ic in

  match status with
  | Unix.WEXITED 0 -> (
      try Ok (Yojson.Safe.from_string output)
      with _ -> Error "Failed to parse JSON response")
  | _ -> Error "Failed to fetch model info"