package kaun
sectionYPositions = computeSectionYPositions($el), 10)"
x-init="setTimeout(() => sectionYPositions = computeSectionYPositions($el), 10)"
>
Neural networks for OCaml
Install
dune-project
Dependency
Authors
Maintainers
Sources
raven-1.0.0.alpha3.tbz
sha256=96d35ce03dfbebd2313657273e24c2e2d20f9e6c7825b8518b69bd1d6ed5870f
sha512=90c5053731d4108f37c19430e45456063e872b04b8a1bbad064c356e1b18e69222de8bfcf4ec14757e71f18164ec6e4630ba770dbcb1291665de5418827d1465
doc/src/kaun.datasets/mnist.ml.html
Source file mnist.ml
(*---------------------------------------------------------------------------
  Copyright (c) 2026 The Raven authors. All rights reserved.
  SPDX-License-Identifier: ISC
  ---------------------------------------------------------------------------*)

open Bigarray
open Dataset_utils

let src = Logs.Src.create "kaun.datasets.mnist" ~doc:"MNIST dataset loader"

module Log = (val Logs.src_log src : Logs.LOG)

(** Download locations and expected IDX magic numbers for the MNIST-format
    datasets supported by this loader. *)
module Config = struct
  type t = {
    name : string;  (** Human-readable name, used in log/error messages. *)
    cache_subdir : string;
        (** Passed to [get_cache_dir]; both presets below end in ['/']
            because file names are appended with [^]. *)
    train_images_url : string;
    train_labels_url : string;
    test_images_url : string;
    test_labels_url : string;
    image_magic_number : int;  (** Expected magic of the image files. *)
    label_magic_number : int;  (** Expected magic of the label files. *)
  }

  let mnist =
    {
      name = "MNIST";
      cache_subdir = "mnist/";
      train_images_url =
        "https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz";
      train_labels_url =
        "https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz";
      test_images_url =
        "https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz";
      test_labels_url =
        "https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz";
      image_magic_number = 2051;
      label_magic_number = 2049;
    }

  (* NOTE(review): these URLs are plain HTTP and no checksum verification is
     visible in this module; confirm whether [ensure_file] validates the
     download. *)
  let fashion_mnist =
    {
      name = "Fashion-MNIST";
      cache_subdir = "fashion-mnist/";
      train_images_url =
        "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz";
      train_labels_url =
        "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz";
      test_images_url =
        "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz";
      test_labels_url =
        "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz";
      image_magic_number = 2051;
      label_magic_number = 2049;
    }
end

(* Decompressed file names inside the dataset cache directory. They are shared
   by [ensure_dataset] (which creates them) and [load] (which reads them), so
   the two cannot drift apart. *)
let train_images_file = "train-images-idx3-ubyte"
let train_labels_file = "train-labels-idx1-ubyte"
let test_images_file = "t10k-images-idx3-ubyte"
let test_labels_file = "t10k-labels-idx1-ubyte"

(** [read_int32_be s pos] decodes the 4 bytes of [s] at [pos .. pos + 3] as a
    big-endian unsigned integer.
    @raise Invalid_argument if [s] has fewer than [pos + 4] bytes. *)
let read_int32_be s pos =
  let b1 = Char.code s.[pos] in
  let b2 = Char.code s.[pos + 1] in
  let b3 = Char.code s.[pos + 2] in
  let b4 = Char.code s.[pos + 3] in
  (b1 lsl 24) lor (b2 lsl 16) lor (b3 lsl 8) lor b4

(** [ensure_dataset config] makes sure the four decompressed IDX files of
    [config] exist in its cache directory. For each missing file it fetches the
    [.gz] archive with [ensure_file] and decompresses it with
    [ensure_decompressed_gz]. Files already present are not re-checked.
    @raise Failure if a file cannot be decompressed. *)
let ensure_dataset config =
  let dataset_dir = get_cache_dir config.Config.cache_subdir in
  mkdir_p dataset_dir;
  let files_to_process =
    [
      (train_images_file, config.Config.train_images_url);
      (train_labels_file, config.Config.train_labels_url);
      (test_images_file, config.Config.test_images_url);
      (test_labels_file, config.Config.test_labels_url);
    ]
  in
  List.iter
    (fun (base_filename, url) ->
      let gz_path = dataset_dir ^ base_filename ^ ".gz" in
      let path = dataset_dir ^ base_filename in
      if not (Sys.file_exists path) then (
        Log.debug (fun m ->
            m "File %s not found for %s dataset" base_filename
              config.Config.name);
        ensure_file url gz_path;
        if not (ensure_decompressed_gz ~gz_path ~target_path:path) then
          failwith (Printf.sprintf "Failed to obtain decompressed file %s" path))
      else Log.debug (fun m -> m "Found decompressed file %s" path))
    files_to_process

(** [read_idx_file ~read_header ~create_array ~populate_array ~expected_magic
    config filename] reads the whole IDX file [filename] into memory and
    decodes it.

    - [read_header s] returns the dimension sizes and the byte offset where
      the payload starts.
    - [create_array dims] allocates the destination array.
    - [populate_array arr s offset total_items] copies the payload into it.

    Only 1-D and 3-D layouts are supported. The file length must match the
    header exactly.
    @raise Failure if the file is truncated, has the wrong magic number, an
    unsupported rank, or an unexpected length. *)
let read_idx_file ~read_header ~create_array ~populate_array ~expected_magic
    config filename =
  Log.debug (fun m -> m "Reading %s file: %s" config.Config.name filename);
  let ic = open_in_bin filename in
  let s =
    Fun.protect
      ~finally:(fun () -> close_in ic)
      (fun () -> really_input_string ic (in_channel_length ic))
  in
  let file_len = String.length s in
  let fail_truncated needed =
    failwith
      (Printf.sprintf
         "File %s is truncated: %d bytes, but its IDX header needs at least %d"
         filename file_len needed)
  in
  (* Guard the header reads so a short (e.g. partially downloaded) file yields
     a descriptive error naming the file, rather than a bare
     [Invalid_argument "index out of bounds"]. *)
  if file_len < 4 then fail_truncated 4;
  let magic = read_int32_be s 0 in
  if magic <> expected_magic then
    failwith
      (Printf.sprintf "Invalid magic number %d in %s (expected %d)" magic
         filename expected_magic);
  (* IDX format: the low byte of the magic number is the number of dimensions,
     and each dimension size follows as a 4-byte big-endian integer. *)
  let header_len = 4 + (4 * (magic land 0xff)) in
  if file_len < header_len then fail_truncated header_len;
  let dimensions, data_offset = read_header s in
  let total_items, data_len =
    match dimensions with
    | [| d1 |] -> (d1, d1)
    | [| d1; d2; d3 |] -> (d1, d1 * d2 * d3)
    | _ -> failwith "Unsupported dimension format"
  in
  let expected_len = data_offset + data_len in
  if file_len <> expected_len then
    failwith
      (Printf.sprintf
         "File %s has unexpected length: %d vs %d (header offset %d, data len \
          %d)"
         filename file_len expected_len data_offset data_len);
  let arr = create_array dimensions in
  populate_array arr s data_offset total_items;
  arr

(** [read_images config filename] reads an IDX image file into an
    [int8_unsigned] [Array3] of shape (num_images, num_rows, num_cols), as
    declared by the file header. *)
let read_images config filename =
  let read_header s =
    let num_images = read_int32_be s 4 in
    let num_rows = read_int32_be s 8 in
    let num_cols = read_int32_be s 12 in
    ([| num_images; num_rows; num_cols |], 16)
  in
  let create_array dims =
    Array3.create int8_unsigned c_layout dims.(0) dims.(1) dims.(2)
  in
  let populate_array arr s offset _total_items =
    let num_images = Array3.dim1 arr in
    let num_rows = Array3.dim2 arr in
    let num_cols = Array3.dim3 arr in
    let img_size = num_rows * num_cols in
    for i = 0 to num_images - 1 do
      let start_pos = offset + (i * img_size) in
      for r = 0 to num_rows - 1 do
        (* Pixels are stored row-major, one byte each. *)
        let row_pos = start_pos + (r * num_cols) in
        for c = 0 to num_cols - 1 do
          arr.{i, r, c} <- Char.code s.[row_pos + c]
        done
      done
    done
  in
  read_idx_file ~read_header ~create_array ~populate_array
    ~expected_magic:config.Config.image_magic_number config filename

(** [read_labels config filename] reads an IDX label file into an
    [int8_unsigned] [Array1] with one byte per label. *)
let read_labels config filename =
  let read_header s =
    let num_labels = read_int32_be s 4 in
    ([| num_labels |], 8)
  in
  let create_array dims = Array1.create int8_unsigned c_layout dims.(0) in
  let populate_array arr s offset total_items =
    for i = 0 to total_items - 1 do
      arr.{i} <- Char.code s.[offset + i]
    done
  in
  read_idx_file ~read_header ~create_array ~populate_array
    ~expected_magic:config.Config.label_magic_number config filename

(** [load ~fashion_mnist] loads Fashion-MNIST when [fashion_mnist] is [true],
    and MNIST otherwise. It first fetches any missing files into the cache via
    [ensure_dataset].

    Returns [((train_images, train_labels), (test_images, test_labels))].
    Images are [int8_unsigned] [Array3]s indexed (image, row, col); labels are
    [int8_unsigned] [Array1]s.
    @raise Failure on download/decompression failure or a malformed file. *)
let load ~fashion_mnist =
  let config = if fashion_mnist then Config.fashion_mnist else Config.mnist in
  ensure_dataset config;
  let dataset_dir = get_cache_dir config.Config.cache_subdir in
  let train_images_path = dataset_dir ^ train_images_file in
  let train_labels_path = dataset_dir ^ train_labels_file in
  let test_images_path = dataset_dir ^ test_images_file in
  let test_labels_path = dataset_dir ^ test_labels_file in
  Log.info (fun m -> m "Loading %s datasets..." config.Config.name);
  let train_images = read_images config train_images_path in
  let train_labels = read_labels config train_labels_path in
  let test_images = read_images config test_images_path in
  let test_labels = read_labels config test_labels_path in
  Log.info (fun m -> m "%s loading complete" config.Config.name);
  ((train_images, train_labels), (test_images, test_labels))
sectionYPositions = computeSectionYPositions($el), 10)"
x-init="setTimeout(() => sectionYPositions = computeSectionYPositions($el), 10)"
>