package neural_nets_lib

  1. Overview
  2. Docs
A from-scratch Deep Learning framework with an optimizing compiler, shape inference, concise syntax

Install

dune-project
 Dependency

Authors

Maintainers

Sources

0.6.0.4.tar.gz
md5=5beaaa0b377bec3badffffbf9f4dec4a
sha512=a37a67452746143f0f5ba2e81f98d6fed31fb4397e0a85f4a35aedc805b4e0405ea89d465c6f80941c465fb61d5d6119806cb73b5c5ead925797eb80d19c5ade

doc/src/neural_nets_lib.datasets/cifar10.ml.html

Source file cifar10.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
(* cifar10.ml *)
open Bigarray
open Dataset_utils

let dataset_name = "cifar-10"
let base_dir = get_cache_dir dataset_name
let archive_dir_name = "cifar-10-batches-py"
let dataset_dir = base_dir ^ archive_dir_name ^ "/"
let url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
let tar_path = base_dir ^ Filename.basename url
let test_batch_rel_path = archive_dir_name ^ "/test_batch"

let ensure_dataset () =
  ensure_extracted_archive ~url ~archive_path:tar_path ~extract_dir:base_dir
    ~check_file:test_batch_rel_path

let read_cifar_batch filename =
  Printf.printf "Reading batch file: %s\n%!" filename;
  let ic = open_in_bin filename in
  let s =
    try really_input_string ic (in_channel_length ic)
    with exn ->
      close_in_noerr ic;
      raise exn
  in
  close_in ic;

  let num_bytes = String.length s in
  let bytes_per_image = 3073 in
  if num_bytes mod bytes_per_image <> 0 then
    failwith (Printf.sprintf "File %s has unexpected size %d" filename num_bytes);

  let num_images = num_bytes / bytes_per_image in
  Printf.printf "Found %d images in %s.\n%!" num_images filename;

  let images = Genarray.create int8_unsigned c_layout [| num_images; 32; 32; 3 |] in
  let labels = Genarray.create int8_unsigned c_layout [| num_images |] in

  for i = 0 to num_images - 1 do
    let base_offset = i * bytes_per_image in
    Genarray.set labels [| i |] (Char.code s.[base_offset]);
    let r_offset = base_offset + 1 in
    let g_offset = r_offset + 1024 in
    let b_offset = g_offset + 1024 in
    for row = 0 to 31 do
      for col = 0 to 31 do
        let plane_idx = (row * 32) + col in
        Genarray.set images [| i; row; col; 0 |] (Char.code s.[r_offset + plane_idx]);
        (* Red *)
        Genarray.set images [| i; row; col; 1 |] (Char.code s.[g_offset + plane_idx]);
        (* Green *)
        Genarray.set images [| i; row; col; 2 |] (Char.code s.[b_offset + plane_idx])
        (* Blue *)
      done
    done
  done;
  (images, labels)

let load () =
  ensure_dataset ();
  Printf.printf "Loading CIFAR-10 dataset...\n%!";

  let train_batches_files =
    List.init 5 (fun i -> dataset_dir ^ Printf.sprintf "data_batch_%d" (i + 1))
  in
  let train_batches_data = List.map read_cifar_batch train_batches_files in

  let total_train_images = 50000 in
  (* Create the final training Genarray *)
  let train_images = Genarray.create int8_unsigned c_layout [| total_train_images; 32; 32; 3 |] in
  let train_labels = Genarray.create int8_unsigned c_layout [| total_train_images |] in

  let current_offset = ref 0 in
  List.iter
    (fun (batch_images, batch_labels) ->
      let batch_size = (Genarray.dims batch_labels).(0) in
      let img_slice_dims = [| batch_size; 32; 32; 3 |] in
      let img_slice = Genarray.sub_left train_images !current_offset batch_size in
      (* Ensure the slice has the expected dimensions before blitting *)
      if Genarray.dims img_slice <> img_slice_dims then
        failwith
          (Printf.sprintf
             "Internal error: train image slice dimension mismatch (expected %s, got %s)"
             (String.concat "x" (Array.to_list (Array.map string_of_int img_slice_dims)))
             (String.concat "x" (Array.to_list (Array.map string_of_int (Genarray.dims img_slice)))));

      let lbl_slice = Genarray.sub_left train_labels !current_offset batch_size in
      Genarray.blit batch_images img_slice;
      Genarray.blit batch_labels lbl_slice;
      current_offset := !current_offset + batch_size)
    train_batches_data;

  let test_batch_file = dataset_dir ^ "test_batch" in
  let test_images, test_labels = read_cifar_batch test_batch_file in

  Printf.printf "CIFAR-10 loading complete.\n%!";
  ((train_images, train_labels), (test_images, test_labels))
OCaml

Innovation. Community. Security.