package kaun

Flax-inspired neural network library for OCaml

Source file lenet.ml

open Rune

(* Configuration *)

type config = {
  num_classes : int;
  input_channels : int;
  input_size : int * int;
  activation : [ `tanh | `relu | `sigmoid ];
  dropout_rate : float option;
}

let default_config =
  {
    num_classes = 10;
    input_channels = 1;
    input_size = (32, 32);
    (* Original LeNet-5 uses 32x32 *)
    activation = `tanh;
    (* Original used tanh *)
    dropout_rate = None;
  }

let mnist_config =
  {
    default_config with
    input_size = (28, 28);
    (* MNIST is 28x28; [create] derives the classifier width from this *)
  }

let cifar10_config =
  {
    num_classes = 10;
    input_channels = 3;
    (* RGB *)
    input_size = (32, 32);
    activation = `relu;
    (* Modern choice *)
    dropout_rate = Some 0.5;
  }
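
(* [config] is a plain record, so task-specific variants are cheap to derive
   with functional record update. An illustrative sketch (this config is not
   part of the library): *)
let _fashion_mnist_config =
  { mnist_config with activation = `relu; dropout_rate = Some 0.25 }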

(* Model Definition *)

type t = Kaun.module_

let create ?(config = default_config) () =
  let open Kaun.Layer in
  (* Select activation function *)
  let activation_fn =
    match config.activation with
    | `tanh -> tanh ()
    | `relu -> relu ()
    | `sigmoid -> sigmoid ()
  in

  (* The classifier input width depends on the spatial input size: each 5x5
     convolution (no padding) shrinks each dimension by 4, and each 2x2 pool
     with stride 2 halves it. Computing it here keeps [create] correct for
     both 32x32 and 28x28 inputs. *)
  let feature_dim d = (((d - 4) / 2) - 4) / 2 in
  let h, w = config.input_size in
  let flat_features = 16 * feature_dim h * feature_dim w in

  (* Build layers *)
  let layers =
    [
      (* First convolutional block *)
      (* Conv1: 6 filters of 5x5 *)
      conv2d ~in_channels:config.input_channels ~out_channels:6
        ~kernel_size:(5, 5) ();
      activation_fn;
      (* Pool1: 2x2 average pooling (original used average, modern uses max) *)
      avg_pool2d ~kernel_size:(2, 2) ~stride:(2, 2) ();
      (* Second convolutional block *)
      (* Conv2: 16 filters of 5x5 *)
      conv2d ~in_channels:6 ~out_channels:16 ~kernel_size:(5, 5) ();
      activation_fn;
      (* Pool2: 2x2 average pooling *)
      avg_pool2d ~kernel_size:(2, 2) ~stride:(2, 2) ();
      (* Flatten for fully connected layers *)
      flatten ();
      (* Fully connected layers. The flattened width follows from the input:
         32x32 -> 14x14 after conv1+pool1 -> 5x5 after conv2+pool2, so
         16 * 5 * 5 = 400; 28x28 -> 12x12 -> 4x4, so 16 * 4 * 4 = 256. *)
      (* FC1: 120 units *)
      linear ~in_features:flat_features ~out_features:120 ();
      activation_fn;
    ]
    (* Add optional dropout *)
    @ (match config.dropout_rate with
      | Some rate -> [ dropout ~rate () ]
      | None -> [])
    @ [
        (* FC2: 84 units *)
        linear ~in_features:120 ~out_features:84 ();
        activation_fn;
      ]
    (* Add optional dropout *)
    @ (match config.dropout_rate with
      | Some rate -> [ dropout ~rate () ]
      | None -> [])
    @ [
        (* Output layer *)
        linear ~in_features:84 ~out_features:config.num_classes ();
        (* No activation for logits output *)
      ]
  in

  sequential layers

let for_mnist () = create ~config:mnist_config ()
let for_cifar10 () = create ~config:cifar10_config ()

(* Forward Pass *)

let forward ~model ~params ~training ~input =
  Kaun.apply model params ~training input
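
(* Example: build the MNIST variant and run a forward pass. The [Kaun.init]
   call below is a sketch of a Flax-style initializer and is an assumption
   about Kaun's API, not verified against this version:

     let model = for_mnist () in
     let params = Kaun.init model ~rngs ~dtype in
     let logits = forward ~model ~params ~training:false ~input:images in
     ...
*)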

let extract_features ~model:_ ~params:_ ~input:_ =
  (* Stub: extracting intermediate activations would require restructuring
     the model to expose them, so this currently just raises. *)
  failwith "extract_features not implemented yet"

(* Model Statistics *)

let num_parameters params =
  let tensors = Kaun.Ptree.flatten_with_paths params in
  List.fold_left
    (fun acc (_, t) -> acc + Array.fold_left ( * ) 1 (shape t))
    0 tensors
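
(* Sanity check: for the classic 32x32 configuration (with a bias term in
   every conv and linear layer, which we assume Kaun adds by default) the
   count works out by hand to:
     conv1: 6*(1*5*5) + 6   = 156
     conv2: 16*(6*5*5) + 16 = 2416
     fc1:   400*120 + 120   = 48120
     fc2:   120*84 + 84     = 10164
     fc3:   84*10 + 10      = 850
   giving the well-known LeNet-5 total of 61706 parameters. *)
let expected_lenet5_parameters = 156 + 2416 + 48120 + 10164 + 850 (* 61706 *)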

let parameter_breakdown params =
  let tensors = Kaun.Ptree.flatten_with_paths params in
  let breakdown = Buffer.create 256 in
  Buffer.add_string breakdown "LeNet-5 Parameter Breakdown:\n";
  Buffer.add_string breakdown "============================\n";

  let layer_params = Hashtbl.create 10 in

  (* Group parameters by layer *)
  List.iter
    (fun (name, tensor) ->
      let layer_name =
        (* Extract layer name from parameter path *)
        try
          let idx = String.index name '.' in
          String.sub name 0 idx
        with Not_found -> name
      in
      let size = Array.fold_left ( * ) 1 (shape tensor) in
      let current =
        try Hashtbl.find layer_params layer_name with Not_found -> 0
      in
      Hashtbl.replace layer_params layer_name (current + size))
    tensors;

  (* Print breakdown; Hashtbl.iter visits layers in unspecified order *)
  Hashtbl.iter
    (fun layer count ->
      Buffer.add_string breakdown
        (Printf.sprintf "  %s: %d parameters\n" layer count))
    layer_params;

  let total = num_parameters params in
  Buffer.add_string breakdown
    (Printf.sprintf "\nTotal: %d parameters (%.2f MB with float32)\n" total
       (float_of_int (total * 4) /. 1024. /. 1024.));

  Buffer.contents breakdown

(* Training Helpers *)

type train_config = {
  learning_rate : float;
  batch_size : int;
  num_epochs : int;
  weight_decay : float option;
  momentum : float option;
}

let default_train_config =
  {
    learning_rate = 0.01;
    batch_size = 64;
    num_epochs = 10;
    weight_decay = Some 0.0001;
    momentum = Some 0.9;
  }
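
(* No training loop ships in this file; [train_config] only carries
   hyperparameters. For reference, the per-parameter SGD-with-momentum
   update these fields describe is, as a scalar sketch (conventions
   assumed here, not taken from Kaun's optimizer):
     v <- momentum * v + g + weight_decay * w
     w <- w - learning_rate * v *)
let sgd_step cfg ~v ~w ~g =
  let wd = Option.value cfg.weight_decay ~default:0.0 in
  let mu = Option.value cfg.momentum ~default:0.0 in
  let v' = (mu *. v) +. g +. (wd *. w) in
  (v', w -. (cfg.learning_rate *. v'))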

let accuracy ~predictions ~labels =
  (* Get predicted classes *)
  let pred_classes = argmax predictions ~axis:1 in
  (* Cast labels to same type for comparison *)
  let labels_int32 = cast Int32 labels in
  (* Compute accuracy *)
  let correct = equal pred_classes labels_int32 in
  let correct_float = cast Float32 correct in
  let total = float_of_int (Array.get (shape labels) 0) in
  (* Sum and convert to float *)
  let num_correct = sum correct_float in
  let num_correct_scalar =
    (* Extract scalar value - simplified version *)
    match shape num_correct with
    | [||] ->
        (* It's already a scalar, extract the value *)
        let arr = to_array num_correct in
        arr.(0)
    | _ -> failwith "Expected scalar result from sum"
  in
  num_correct_scalar /. total
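
(* Example use, assuming [predictions] is a [batch; num_classes] tensor of
   logits and [labels] a length-[batch] tensor of class indices:

     let acc = accuracy ~predictions ~labels in
     Printf.printf "validation accuracy: %.2f%%\n" (100. *. acc)
*)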