package kaun
Flax-inspired neural network library for OCaml
Install
dune-project
Dependency
Authors
Maintainers
Sources
raven-1.0.0.alpha1.tbz
sha256=8e277ed56615d388bc69c4333e43d1acd112b5f2d5d352e2453aef223ff59867
sha512=369eda6df6b84b08f92c8957954d107058fb8d3d8374082e074b56f3a139351b3ae6e3a99f2d4a4a2930dd950fd609593467e502368a13ad6217b571382da28c
doc/src/kaun.huggingface/kaun_huggingface.ml.html
Source file kaun_huggingface.ml
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257
open Rune

(** Implementation of HuggingFace model hub integration for Kaun *)

(* Types *)

(** Hub repository identifier; it is interpolated as-is into hub URLs. *)
type model_id = string

(** Which revision of a repository to fetch. [Latest] resolves to ["main"]. *)
type revision = Latest | Tag of string | Commit of string

type cache_dir = string

(** Statistics reported after a download. [rate] is bytes per second. *)
type download_progress = {
  downloaded_bytes : int;
  total_bytes : int option;
  rate : float;
}

(** Result of a fetch: served from the local cache, or freshly downloaded
    together with the statistics of that download. *)
type 'a download_result = Cached of 'a | Downloaded of 'a * download_progress

(* Configuration *)

module Config = struct
  type t = {
    cache_dir : cache_dir;
    token : string option;
    offline_mode : bool;
    force_download : bool;
    show_progress : bool;
  }
  (* NOTE(review): [token] is stored here but nothing in this module sends it
     with a request, so gated/private repositories cannot be fetched yet. *)

  (** Default configuration: cache under [$HOME/.cache/kaun/huggingface]
      (falling back to [/tmp] when [HOME] is unset), no token, online, no
      forced re-download, progress shown. *)
  let default =
    let home =
      match Sys.getenv_opt "HOME" with Some dir -> dir | None -> "/tmp"
    in
    {
      cache_dir = Filename.concat home ".cache/kaun/huggingface";
      token = None;
      offline_mode = false;
      force_download = false;
      show_progress = true;
    }

  (** [from_env ()] builds a configuration from the [KAUN_HF_*] environment
      variables. Boolean flags are enabled by ["true"] or ["1"];
      [KAUN_HF_SHOW_PROGRESS] is on unless set to ["false"] or ["0"]. *)
  let from_env () =
    let get_env_bool var =
      match Sys.getenv_opt var with
      | Some "true" | Some "1" -> true
      | Some _ | None -> false
    in
    {
      cache_dir =
        (match Sys.getenv_opt "KAUN_HF_CACHE_DIR" with
        | Some dir -> dir
        | None -> default.cache_dir);
      token = Sys.getenv_opt "KAUN_HF_TOKEN";
      offline_mode = get_env_bool "KAUN_HF_OFFLINE_MODE";
      force_download = get_env_bool "KAUN_HF_FORCE_DOWNLOAD";
      show_progress =
        (match Sys.getenv_opt "KAUN_HF_SHOW_PROGRESS" with
        | Some "false" | Some "0" -> false
        | Some _ | None -> true);
    }
end

(* Model Registry *)

module Registry = struct
  type ('params, 'a, 'dev) model_spec = {
    architecture : string;
    config_file : string;
    weight_files : string list;
    load_config : Yojson.Safe.t -> 'params;
    build_params : dtype:(float, 'a) dtype -> 'params -> 'a Kaun.params;
  }

  (* Specs of different types share one table, so they are stored untyped. *)
  let registry : (string, Obj.t) Hashtbl.t = Hashtbl.create 10

  (** [register name spec] stores [spec] under [name], replacing any previous
      entry. *)
  let register name spec = Hashtbl.replace registry name (Obj.repr spec)

  (** [get name] returns the spec registered under [name], if any.

      The result type is chosen by the caller and is NOT checked: it must be
      the type the spec was registered with. *)
  let get name =
    match Hashtbl.find_opt registry name with
    | Some spec -> Some (Obj.obj spec)
    | None -> None
end

(* Utilities *)

(** [ensure_dir dir] creates [dir] and any missing parents with mode [0o755].
    A directory appearing concurrently ([EEXIST]) is not an error. *)
let ensure_dir dir =
  let rec mkdir_p path =
    if not (Sys.file_exists path) then (
      mkdir_p (Filename.dirname path);
      try Unix.mkdir path 0o755
      with Unix.Unix_error (Unix.EEXIST, _, _) -> ())
  in
  mkdir_p dir

(** [hub_url ~model_id ~filename ~revision] is the hub "resolve" URL of
    [filename] in [model_id] at [revision]. *)
let hub_url ~model_id ~filename ~revision =
  let revision_str =
    match revision with
    | Latest -> "main"
    | Tag tag -> tag
    | Commit commit -> commit
  in
  Printf.sprintf "https://huggingface.co/%s/resolve/%s/%s" model_id
    revision_str filename

(* Cache directory name of a model: every '/' of the id becomes '-'. *)
let model_dir_name model_id =
  String.map (fun c -> if c = '/' then '-' else c) model_id

(** [cache_path config ~model_id ~filename ~revision] is the local path at
    which [filename] is cached:
    [<cache_dir>/<model dir>/<main | tags/T | commits/C>/<filename>]. *)
let cache_path config ~model_id ~filename ~revision =
  let revision_str =
    match revision with
    | Latest -> "main"
    | Tag tag -> "tags/" ^ tag
    | Commit commit -> "commits/" ^ commit
  in
  Filename.concat config.Config.cache_dir
    (Filename.concat (model_dir_name model_id)
       (Filename.concat revision_str filename))

(* Core Loading Functions *)

(** [download_with_progress ~url ~dest ~show_progress] fetches [url] into
    [dest] by running [curl], creating the parent directory of [dest] first.

    The transfer goes to [dest ^ ".part"] and is renamed onto [dest] only when
    curl succeeds, so an interrupted or failed download never leaves a file at
    [dest] that a later call would mistake for a valid cache entry.

    @return the size of the downloaded file and the average transfer rate.
    @raise Failure if curl exits with a non-zero status or is killed. *)
let download_with_progress ~url ~dest ~show_progress =
  ensure_dir (Filename.dirname dest);
  let partial = dest ^ ".part" in
  (* -s silences curl's own progress meter when progress is not wanted. *)
  let progress_flag = if show_progress then "" else "-s" in
  (* -f makes curl fail on HTTP errors (404, 401, ...) instead of saving the
     error page as if it were the file. Both arguments are shell-quoted: they
     derive from caller-supplied model ids and filenames. *)
  let cmd =
    Printf.sprintf "curl -f -L %s -o %s %s" progress_flag
      (Filename.quote partial) (Filename.quote url)
  in
  if show_progress then Printf.printf "Downloading from %s...\n%!" url;
  let start_time = Unix.gettimeofday () in
  match Unix.system cmd with
  | Unix.WEXITED 0 ->
      let elapsed = Unix.gettimeofday () -. start_time in
      Sys.rename partial dest;
      let bytes = (Unix.stat dest).Unix.st_size in
      (* Guard against a zero (or negative, after a clock step) interval. *)
      let rate =
        if elapsed > 0. then float_of_int bytes /. elapsed else 0.
      in
      { downloaded_bytes = bytes; total_bytes = Some bytes; rate }
  | Unix.WEXITED _ | Unix.WSIGNALED _ | Unix.WSTOPPED _ ->
      (* Best-effort cleanup; the partial file may not exist at all. *)
      (try Sys.remove partial with Sys_error _ -> ());
      failwith (Printf.sprintf "Failed to download %s" url)

(** [download_file ?config ?revision ~model_id ~filename ()] returns the local
    path of [filename], downloading it into the cache when needed.

    @return
      [Cached path] when the file is already cached and
      [config.force_download] is off, [Downloaded (path, progress)] otherwise.
    @raise Failure
      in offline mode when the file is not cached, or when the download
      fails. *)
let download_file ?(config = Config.default) ?(revision = Latest) ~model_id
    ~filename () =
  let local_path = cache_path config ~model_id ~filename ~revision in
  if Sys.file_exists local_path && not config.Config.force_download then
    Cached local_path
  else if config.Config.offline_mode then
    failwith
      (Printf.sprintf "File not in cache (offline mode): %s" local_path)
  else
    let url = hub_url ~model_id ~filename ~revision in
    let progress =
      download_with_progress ~url ~dest:local_path
        ~show_progress:config.Config.show_progress
    in
    Downloaded (local_path, progress)

(** [load_safetensors ?config ?revision ~model_id ~dtype ()] fetches and
    restores the model weights, trying a fixed list of common safetensors
    filenames in order and returning the first one that loads.

    @raise Failure if none of the candidate files can be fetched and loaded. *)
let load_safetensors ?(config = Config.default) ?(revision = Latest) ~model_id
    ~dtype () =
  let filenames =
    [
      "model.safetensors";
      "pytorch_model.safetensors";
      "model-00001-of-00001.safetensors";
    ]
  in
  (* Fetch one candidate and restore it through the Kaun checkpointer. *)
  let load_one filename =
    let result = download_file ~config ~revision ~model_id ~filename () in
    let local_path =
      match result with Cached path | Downloaded (path, _) -> path
    in
    let checkpointer = Kaun.Checkpoint.Checkpointer.create () in
    let params =
      Kaun.Checkpoint.Checkpointer.restore_file checkpointer ~path:local_path
        ~dtype
    in
    match result with
    | Cached _ -> Cached params
    | Downloaded (_, progress) -> Downloaded (params, progress)
  in
  let rec try_files = function
    | [] ->
        failwith (Printf.sprintf "No safetensors file found for %s" model_id)
    | filename :: rest -> (
        match load_one filename with
        | loaded -> loaded
        (* Resource exhaustion and user interrupts are not "file missing":
           let them propagate instead of moving on to the next candidate. *)
        | exception ((Out_of_memory | Stack_overflow | Sys.Break) as fatal) ->
            raise fatal
        | exception _ -> try_files rest)
  in
  try_files filenames

(** [load_config ?config ?revision ~model_id ()] fetches [config.json] for
    [model_id] and parses it as JSON.

    @raise Failure if the file cannot be obtained (see {!download_file}). *)
let load_config ?(config = Config.default) ?(revision = Latest) ~model_id () =
  let result =
    download_file ~config ~revision ~model_id ~filename:"config.json" ()
  in
  let local_path =
    match result with Cached path | Downloaded (path, _) -> path
  in
  let json = Yojson.Safe.from_file local_path in
  match result with
  | Cached _ -> Cached json
  | Downloaded (_, progress) -> Downloaded (json, progress)

(* High-level Model Loading *)

(** [from_pretrained ?config ?revision ~model_id ~dtype ()] returns the
    parameters loaded by {!load_safetensors}, dropping the cache/download
    status. *)
let from_pretrained ?(config = Config.default) ?(revision = Latest) ~model_id
    ~dtype () =
  match load_safetensors ~config ~revision ~model_id ~dtype () with
  | Cached params | Downloaded (params, _) -> params

(* Utilities *)

(** [list_cached_models ?config ()] lists the sub-directories of the cache
    directory, with every ['-'] turned back into ['/'].

    NOTE(review): this reverse mapping is lossy. A model id that itself
    contains ['-'] (e.g. ["org/my-model"], cached as ["org-my-model"]) is
    reported as ["org/my/model"]. *)
let list_cached_models ?(config = Config.default) () =
  let cache_dir = config.Config.cache_dir in
  if not (Sys.file_exists cache_dir) then []
  else
    Sys.readdir cache_dir |> Array.to_list
    |> List.filter (fun e -> Sys.is_directory (Filename.concat cache_dir e))
    |> List.map (fun e -> String.map (fun c -> if c = '-' then '/' else c) e)

(** [clear_cache ?config ?model_id ()] deletes the cached files of [model_id],
    or the whole cache directory when [model_id] is omitted. Does nothing if
    the target does not exist.

    Symbolic links found inside the tree are unlinked, never followed, so
    nothing outside the cache is deleted through them. *)
let clear_cache ?(config = Config.default) ?model_id () =
  let rec remove_dir path =
    Array.iter
      (fun entry -> remove_child (Filename.concat path entry))
      (Sys.readdir path);
    Unix.rmdir path
  and remove_child path =
    (* lstat, unlike Sys.is_directory, does not follow symlinks. *)
    match (Unix.lstat path).Unix.st_kind with
    | Unix.S_DIR -> remove_dir path
    | _ -> Sys.remove path
  in
  let rm_rf path =
    if Sys.is_directory path then remove_dir path else Sys.remove path
  in
  let remove_if_present path = if Sys.file_exists path then rm_rf path in
  match model_id with
  | Some id ->
      remove_if_present
        (Filename.concat config.Config.cache_dir (model_dir_name id))
  | None -> remove_if_present config.Config.cache_dir

(** [get_model_info model_id] queries the hub's [/api/models/<model_id>]
    endpoint with [curl].

    @return
      [Ok json] when curl exits with status 0 and its output parses as JSON,
      [Error msg] otherwise. *)
let get_model_info model_id =
  let url = Printf.sprintf "https://huggingface.co/api/models/%s" model_id in
  (* Shell-quote the URL: it embeds the caller-supplied model id. *)
  let cmd = Printf.sprintf "curl -s %s" (Filename.quote url) in
  let ic = Unix.open_process_in cmd in
  (* Accumulate in a buffer: linear time, constant stack. *)
  let buf = Buffer.create 4096 in
  (try
     while true do
       Buffer.add_string buf (input_line ic);
       Buffer.add_char buf '\n'
     done
   with End_of_file -> ());
  let output = Buffer.contents buf in
  match Unix.close_process_in ic with
  | Unix.WEXITED 0 -> (
      (* Any parse failure is reported through the result, not raised. *)
      match Yojson.Safe.from_string output with
      | json -> Ok json
      | exception _ -> Error "Failed to parse JSON response")
  | Unix.WEXITED _ | Unix.WSIGNALED _ | Unix.WSTOPPED _ ->
      Error "Failed to fetch model info"
sectionYPositions = computeSectionYPositions($el), 10)"
x-init="setTimeout(() => sectionYPositions = computeSectionYPositions($el), 10)"
>