package kaun

  1. Overview
  2. Docs

Source file mnist.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
(*---------------------------------------------------------------------------
  Copyright (c) 2026 The Raven authors. All rights reserved.
  SPDX-License-Identifier: ISC
  ---------------------------------------------------------------------------*)

open Bigarray
open Dataset_utils

(** Log source under which this loader reports its messages. *)
let src = Logs.Src.create ~doc:"MNIST dataset loader" "kaun.datasets.mnist"

(** Logging functions bound to {!src}. *)
module Log = (val Logs.src_log src : Logs.LOG)

(** Per-dataset settings: where the archives live, where they are cached and
    which magic numbers their decompressed files must start with. *)
module Config = struct
  type t = {
    name : string;  (** Dataset name, used in log messages. *)
    cache_subdir : string;  (** Sub-directory passed to [get_cache_dir]. *)
    train_images_url : string;  (** URL of the gzipped training images. *)
    train_labels_url : string;  (** URL of the gzipped training labels. *)
    test_images_url : string;  (** URL of the gzipped test images. *)
    test_labels_url : string;  (** URL of the gzipped test labels. *)
    image_magic_number : int;
        (** Value expected in the first four bytes of an image file. *)
    label_magic_number : int;
        (** Value expected in the first four bytes of a label file. *)
  }

  (** Settings for the original MNIST digits dataset. *)
  let mnist =
    let image_magic_number = 2051 and label_magic_number = 2049 in
    {
      name = "MNIST";
      cache_subdir = "mnist/";
      train_images_url =
        "https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz";
      train_labels_url =
        "https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz";
      test_images_url =
        "https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz";
      test_labels_url =
        "https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz";
      image_magic_number;
      label_magic_number;
    }

  (** Settings for Fashion-MNIST. Only the name, cache location and URLs
      differ from {!mnist}; the magic numbers are inherited unchanged. *)
  let fashion_mnist =
    {
      mnist with
      name = "Fashion-MNIST";
      cache_subdir = "fashion-mnist/";
      train_images_url =
        "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz";
      train_labels_url =
        "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz";
      test_images_url =
        "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz";
      test_labels_url =
        "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz";
    }
end

(** [read_int32_be s pos] decodes the four bytes of [s] starting at [pos] as
    a big-endian integer (most significant byte first).

    @raise Invalid_argument if any of the four positions lies outside [s]. *)
let read_int32_be s pos =
  let acc = ref 0 in
  for k = 0 to 3 do
    (* Shift what we have so far up one byte and append the next one. *)
    acc := (!acc lsl 8) lor Char.code s.[pos + k]
  done;
  !acc

(** [ensure_dataset config] makes sure the four decompressed data files of
    the dataset described by [config] exist in its cache directory.

    For every file that is missing, [ensure_file] is called with the
    archive URL and the [.gz] path, then [ensure_decompressed_gz] is asked to
    produce the decompressed file. (Presumably these helpers download and
    gunzip respectively; they are defined in [Dataset_utils], confirm there.)

    @raise Failure if [ensure_decompressed_gz] returns [false] for a file. *)
let ensure_dataset config =
  let dataset_dir = get_cache_dir config.Config.cache_subdir in
  mkdir_p dataset_dir;
  let obtain base_filename url =
    let path = dataset_dir ^ base_filename in
    if Sys.file_exists path then
      Log.debug (fun m -> m "Found decompressed file %s" path)
    else begin
      let gz_path = dataset_dir ^ base_filename ^ ".gz" in
      Log.debug (fun m ->
          m "File %s not found for %s dataset" base_filename config.name);
      ensure_file url gz_path;
      let decompressed = ensure_decompressed_gz ~gz_path ~target_path:path in
      if not decompressed then
        failwith (Printf.sprintf "Failed to obtain decompressed file %s" path)
    end
  in
  (* Files are handled one after the other, in this fixed order. *)
  obtain "train-images-idx3-ubyte" config.Config.train_images_url;
  obtain "train-labels-idx1-ubyte" config.Config.train_labels_url;
  obtain "t10k-images-idx3-ubyte" config.Config.test_images_url;
  obtain "t10k-labels-idx1-ubyte" config.Config.test_labels_url

(** [read_idx_file ~read_header ~create_array ~populate_array ~expected_magic
    config filename] reads all of [filename] into memory, validates it and
    returns the array produced by [create_array] after [populate_array] has
    filled it.

    - [read_header s] returns the dimensions stored in the header and the
      offset at which the payload starts.
    - [create_array dims] allocates the result for those dimensions.
    - [populate_array arr s offset n] copies the payload of [s], starting at
      [offset], into [arr]; [n] is the first dimension.
    - [expected_magic] is compared with the big-endian integer held in the
      first four bytes of the file.
    - [config] is only used for its name in the log message.

    The file must be exactly header plus payload long, where the payload has
    one byte per element (product of the dimensions).

    @raise Failure if the file is too short to contain its header, if the
    magic number differs from [expected_magic], if the dimensions are neither
    1-D nor 3-D, or if the file length does not match the header.
    @raise Sys_error if the file cannot be opened or read. *)
let read_idx_file ~read_header ~create_array ~populate_array ~expected_magic
    config filename =
  Log.debug (fun m -> m "Reading %s file: %s" config.Config.name filename);
  let ic = open_in_bin filename in
  let s =
    Fun.protect
      ~finally:(fun () -> close_in ic)
      (fun () -> really_input_string ic (in_channel_length ic))
  in
  (* A truncated or empty file would otherwise surface as a bare
     [Invalid_argument "index out of bounds"] from the header reads below;
     report it like every other validation failure instead. *)
  let truncated () =
    failwith
      (Printf.sprintf "File %s is too short to contain an IDX header (%d bytes)"
         filename (String.length s))
  in
  if String.length s < 4 then truncated ();
  let magic = read_int32_be s 0 in
  if magic <> expected_magic then
    failwith
      (Printf.sprintf "Invalid magic number %d in %s (expected %d)" magic
         filename expected_magic);
  let dimensions, data_offset =
    (* The header size is only known to [read_header], so out-of-bounds
       reads there are the signal that the header is incomplete. *)
    try read_header s with Invalid_argument _ -> truncated ()
  in
  let total_items, data_len =
    match dimensions with
    | [| d1 |] -> (d1, d1)
    | [| d1; d2; d3 |] -> (d1, d1 * d2 * d3)
    | _ -> failwith "Unsupported dimension format"
  in
  let expected_len = data_offset + data_len in
  if String.length s <> expected_len then
    failwith
      (Printf.sprintf
         "File %s has unexpected length: %d vs %d (header offset %d, data len \
          %d)"
         filename (String.length s) expected_len data_offset data_len);
  let arr = create_array dimensions in
  populate_array arr s data_offset total_items;
  arr

(** [read_images config filename] loads an image file through
    [read_idx_file] and returns its pixels as an unsigned 8-bit
    [Bigarray.Array3] indexed by image, row and column.

    The header supplies the three dimensions at byte offsets 4, 8 and 12; the
    pixel data starts at offset 16. The file's magic number is checked against
    [config.image_magic_number]. Errors are those of [read_idx_file]. *)
let read_images config filename =
  let parse_header raw =
    let count = read_int32_be raw 4 in
    let height = read_int32_be raw 8 in
    let width = read_int32_be raw 12 in
    ([| count; height; width |], 16)
  in
  let allocate dims =
    Array3.create int8_unsigned c_layout dims.(0) dims.(1) dims.(2)
  in
  let fill pixels raw offset _total_items =
    let count = Array3.dim1 pixels
    and height = Array3.dim2 pixels
    and width = Array3.dim3 pixels in
    (* Pixels are stored image by image, row by row, so one cursor walking
       the payload visits them in exactly the loop order below. *)
    let cursor = ref offset in
    for img = 0 to count - 1 do
      for row = 0 to height - 1 do
        for col = 0 to width - 1 do
          pixels.{img, row, col} <- Char.code raw.[!cursor];
          incr cursor
        done
      done
    done
  in
  read_idx_file ~read_header:parse_header ~create_array:allocate
    ~populate_array:fill ~expected_magic:config.Config.image_magic_number
    config filename

(** [read_labels config filename] loads a label file through
    [read_idx_file] and returns the labels as an unsigned 8-bit
    [Bigarray.Array1], one byte per label.

    The header supplies the label count at byte offset 4; the labels start at
    offset 8. The file's magic number is checked against
    [config.label_magic_number]. Errors are those of [read_idx_file]. *)
let read_labels config filename =
  let parse_header raw = ([| read_int32_be raw 4 |], 8) in
  let allocate dims = Array1.create int8_unsigned c_layout dims.(0) in
  let fill labels raw offset count =
    let rec copy_from i =
      if i < count then begin
        labels.{i} <- Char.code raw.[offset + i];
        copy_from (i + 1)
      end
    in
    copy_from 0
  in
  read_idx_file ~read_header:parse_header ~create_array:allocate
    ~populate_array:fill ~expected_magic:config.Config.label_magic_number
    config filename

(** [load ~fashion_mnist] returns
    [((train_images, train_labels), (test_images, test_labels))] for
    Fashion-MNIST when [fashion_mnist] is [true] and for MNIST otherwise.

    It first calls [ensure_dataset] so the decompressed files exist in the
    cache directory, then reads the four files in the order train images,
    train labels, test images, test labels. Errors raised by
    [ensure_dataset], [read_images] and [read_labels] propagate unchanged. *)
let load ~fashion_mnist =
  let config =
    match fashion_mnist with
    | true -> Config.fashion_mnist
    | false -> Config.mnist
  in
  ensure_dataset config;
  let dataset_dir = get_cache_dir config.Config.cache_subdir in
  let in_cache base_filename = dataset_dir ^ base_filename in
  Log.info (fun m -> m "Loading %s datasets..." config.name);
  (* Explicit lets pin the read order; tuple fields have no guaranteed
     evaluation order. *)
  let train_images = read_images config (in_cache "train-images-idx3-ubyte") in
  let train_labels = read_labels config (in_cache "train-labels-idx1-ubyte") in
  let test_images = read_images config (in_cache "t10k-images-idx3-ubyte") in
  let test_labels = read_labels config (in_cache "t10k-labels-idx1-ubyte") in
  Log.info (fun m -> m "%s loading complete" config.name);
  ((train_images, train_labels), (test_images, test_labels))