package neural_nets_lib

  1. Overview
  2. Docs

Source file mnist.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
(* mnist.ml *)
open Bigarray
open Dataset_utils

(* Per-dataset settings: display name, cache sub-directory, download URLs and IDX magic numbers. *)
module Config = struct
  (** Description of one dataset stored as four IDX files (train/test images and labels). *)
  type t = {
    name : string;  (** Dataset name, used in progress and log messages. *)
    cache_subdir : string;  (** Sub-directory handed to [get_cache_dir] to locate the files. *)
    train_images_url : string;  (** Download URL of the gzipped training images. *)
    train_labels_url : string;  (** Download URL of the gzipped training labels. *)
    test_images_url : string;  (** Download URL of the gzipped test images. *)
    test_labels_url : string;  (** Download URL of the gzipped test labels. *)
    image_magic_number : int;  (** Value expected in the first four bytes of an image file. *)
    label_magic_number : int;  (** Value expected in the first four bytes of a label file. *)
  }

  (** Settings for the MNIST dataset. *)
  let mnist =
    let train_images_url =
      "https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz"
    and train_labels_url =
      "https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz"
    and test_images_url =
      "https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz"
    and test_labels_url =
      "https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz"
    in
    {
      name = "MNIST";
      cache_subdir = "mnist/";
      train_images_url;
      train_labels_url;
      test_images_url;
      test_labels_url;
      image_magic_number = 2051;
      label_magic_number = 2049;
    }

  (** Settings for the Fashion-MNIST dataset. Only the name, cache location and URLs differ
      from {!mnist}; both magic numbers are inherited from it. *)
  let fashion_mnist =
    {
      mnist with
      name = "Fashion-MNIST";
      cache_subdir = "fashion-mnist/";
      train_images_url =
        "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz";
      train_labels_url =
        "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz";
      test_images_url =
        "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz";
      test_labels_url =
        "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz";
    }
end

(** Top-level alias for [Config.mnist]. *)
let mnist_config : Config.t = Config.mnist

(** Top-level alias for [Config.fashion_mnist]. *)
let fashion_mnist_config : Config.t = Config.fashion_mnist

(* IDX file parsing helpers; they are specific to this format and kept local to this module. *)
(** [read_int32_be s pos] decodes the four bytes [s.[pos]] .. [s.[pos + 3]] as a big-endian
    integer (most significant byte first) and returns it as an [int].

    @raise Invalid_argument if any of the four positions lies outside [s]. *)
let read_int32_be s pos =
  (* Fold the bytes in from the most significant end: each step shifts the running value one
     byte to the left and ORs in the next byte. *)
  let rec accumulate value remaining idx =
    if remaining = 0 then value
    else accumulate ((value lsl 8) lor Char.code s.[idx]) (remaining - 1) (idx + 1)
  in
  accumulate 0 4 pos

(** [ensure_dataset config] makes sure the four decompressed IDX files of [config] are present
    in the dataset's cache directory. For each file that is missing, it makes sure the [.gz]
    archive has been obtained from the configured URL (via [ensure_file]) and decompressed
    (via [ensure_decompressed_gz]). Progress is reported on standard output.

    @raise Failure if [ensure_decompressed_gz] reports that a file could not be produced. *)
let ensure_dataset config =
  let {
    Config.name;
    cache_subdir;
    train_images_url;
    train_labels_url;
    test_images_url;
    test_labels_url;
    _;
  } =
    config
  in
  let dataset_dir = get_cache_dir cache_subdir in
  (* Ensure base dir exists *)
  mkdir_p dataset_dir;

  (* Handles one (decompressed file name, download URL) pair. *)
  let fetch_if_missing (base_filename, url) =
    let path = dataset_dir ^ base_filename in
    if Sys.file_exists path then Printf.printf "Found decompressed file %s.\n%!" path
    else begin
      Printf.printf "File %s not found for %s dataset.\n%!" base_filename name;
      let gz_path = dataset_dir ^ base_filename ^ ".gz" in
      (* Ensure the .gz file is downloaded *)
      ensure_file url gz_path;
      (* Ensure it's decompressed *)
      let decompressed = ensure_decompressed_gz ~gz_path ~target_path:path in
      if not decompressed then
        failwith (Printf.sprintf "Failed to obtain decompressed file %s" path)
    end
  in
  List.iter fetch_if_missing
    [
      ("train-images-idx3-ubyte", train_images_url);
      ("train-labels-idx1-ubyte", train_labels_url);
      ("t10k-images-idx3-ubyte", test_images_url);
      ("t10k-labels-idx1-ubyte", test_labels_url);
    ]

(** [read_idx_file ~read_header ~create_array ~populate_array ~expected_magic config filename]
    reads the whole of [filename] into memory, validates it and returns the array built by the
    callbacks.

    - [read_header s] receives the file contents and returns [(dimensions, data_offset)].
      [dimensions] must have one or three entries, and [data_offset] is the position of the
      first data byte.
    - [create_array dimensions] allocates the result.
    - [populate_array arr s data_offset total_items] fills [arr] from [s]. [total_items] is the
      first dimension.
    - [expected_magic] is compared against the big-endian integer in the first four bytes.
    - [config] is only used for its [name] in the progress message.

    @raise Failure
      if the file cannot be read, is too short to hold its magic number or header, has the
      wrong magic number, has an unsupported number of dimensions, or has a length other than
      [data_offset] plus the product of the dimensions.
    @raise Sys_error if the file cannot be opened. *)
let read_idx_file ~read_header ~create_array ~populate_array ~expected_magic config filename =
  Printf.printf "Reading %s file: %s\n%!" config.Config.name filename;
  let ic = open_in_bin filename in
  let s =
    try really_input_string ic (in_channel_length ic)
    with exn ->
      close_in_noerr ic;
      failwith (Printf.sprintf "Error reading file %s: %s" filename (Printexc.to_string exn))
  in
  close_in ic;

  (* A truncated file would otherwise surface as a bare [Invalid_argument "index out of
     bounds"] from the string accesses below. Report it like every other malformed-file case. *)
  let truncated what =
    failwith
      (Printf.sprintf "File %s is too short to contain %s (%d bytes)" filename what
         (String.length s))
  in
  if String.length s < 4 then truncated "an IDX magic number";

  let magic = read_int32_be s 0 in
  if magic <> expected_magic then
    failwith
      (Printf.sprintf "Invalid magic number %d in %s (expected %d)" magic filename expected_magic);

  let dimensions, data_offset =
    (* NOTE(review): assumes [read_header] only raises [Invalid_argument] when it reads past
       the end of [s], as the header readers in this file do. *)
    try read_header s with Invalid_argument _ -> truncated "its IDX header"
  in
  let total_items, data_len =
    match dimensions with
    | [| d1 |] -> (d1, d1)
    | [| d1; d2; d3 |] -> (d1, d1 * d2 * d3)
    | _ -> failwith "Unsupported dimension format"
  in
  (* The file must consist of exactly the header followed by one byte per element. *)
  let expected_len = data_offset + data_len in
  if String.length s <> expected_len then
    failwith
      (Printf.sprintf "File %s has unexpected length: %d vs %d (header offset %d, data len %d)"
         filename (String.length s) expected_len data_offset data_len);

  let arr = create_array dimensions in
  populate_array arr s data_offset total_items;
  arr

(* Readers for image and label IDX files; the expected magic numbers come from the config passed in. *)
(** [read_images config filename] loads an IDX image file into an unsigned 8-bit, C-layout
    [Genarray] of dimensions [[| num_images; num_rows; num_cols |]].

    The header is the magic number followed by three big-endian integers (image count, rows,
    columns) at byte offsets 4, 8 and 12; pixel data starts at offset 16. The magic number is
    checked against [config.image_magic_number].

    @raise Failure if [read_idx_file] rejects the file. *)
let read_images config filename =
  let read_header s =
    let num_images = read_int32_be s 4 in
    let num_rows = read_int32_be s 8 in
    let num_cols = read_int32_be s 12 in
    ([| num_images; num_rows; num_cols |], 16)
  in
  let create_array dims = Genarray.create int8_unsigned c_layout dims in
  let populate_array (arr : (int, int8_unsigned_elt, c_layout) Genarray.t) s offset _ =
    (* In C layout element [i, r, c] sits at linear index [(i * rows + r) * cols + c], which is
       the order of the bytes in the file. Copying through a flat view of the same storage
       therefore gives the same contents as a per-pixel [Genarray.set], without allocating an
       index array for every pixel. *)
    let pixel_count = Array.fold_left ( * ) 1 (Genarray.dims arr) in
    let flat = reshape_1 arr pixel_count in
    for k = 0 to pixel_count - 1 do
      flat.{k} <- Char.code s.[offset + k]
    done
  in
  read_idx_file ~read_header ~create_array ~populate_array
    ~expected_magic:config.Config.image_magic_number config filename

(** [read_labels config filename] loads an IDX label file into a one-dimensional, unsigned
    8-bit, C-layout [Genarray] with one entry per label.

    The header is the magic number followed by the big-endian label count at byte offset 4;
    label bytes start at offset 8. The magic number is checked against
    [config.label_magic_number]. *)
let read_labels config filename =
  let parse_header raw =
    let label_count = read_int32_be raw 4 in
    ([| label_count |], 8)
  in
  let allocate shape = Genarray.create int8_unsigned c_layout shape in
  let fill labels raw data_start count =
    (* Copy the label bytes one by one, in ascending index order. *)
    let rec copy idx =
      if idx < count then begin
        Genarray.set labels [| idx |] (Char.code raw.[data_start + idx]);
        copy (idx + 1)
      end
    in
    copy 0
  in
  read_idx_file ~read_header:parse_header ~create_array:allocate ~populate_array:fill
    ~expected_magic:config.Config.label_magic_number config filename

(** [load ~fashion_mnist] returns [((train_images, train_labels), (test_images, test_labels))]
    for Fashion-MNIST when [fashion_mnist] is [true] and for MNIST otherwise.

    It first calls [ensure_dataset] so the files are present in the cache directory, then reads
    them in the order train images, train labels, test images, test labels. Progress is
    reported on standard output. *)
let load ~fashion_mnist =
  let config =
    match fashion_mnist with true -> Config.fashion_mnist | false -> Config.mnist
  in
  ensure_dataset config;

  let dataset_dir = get_cache_dir config.Config.cache_subdir in
  let in_cache file = dataset_dir ^ file in
  (* Reads one (images, labels) pair, images first. *)
  let read_split ~images ~labels =
    let image_data = read_images config (in_cache images) in
    let label_data = read_labels config (in_cache labels) in
    (image_data, label_data)
  in

  Printf.printf "Loading %s datasets...\n%!" config.Config.name;
  let train = read_split ~images:"train-images-idx3-ubyte" ~labels:"train-labels-idx1-ubyte" in
  let test = read_split ~images:"t10k-images-idx3-ubyte" ~labels:"t10k-labels-idx1-ubyte" in
  Printf.printf "%s loading complete.\n%!" config.Config.name;
  (train, test)
OCaml

Innovation. Community. Security.