package caisar

  1. Overview
  2. Docs

Source file nnet.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
(**************************************************************************)
(*                                                                        *)
(*  This file is part of CAISAR.                                          *)
(*                                                                        *)
(*  Copyright (C) 2023                                                    *)
(*    CEA (Commissariat à l'énergie atomique et aux énergies              *)
(*         alternatives)                                                  *)
(*                                                                        *)
(*  You can redistribute it and/or modify it under the terms of the GNU   *)
(*  Lesser General Public License as published by the Free Software       *)
(*  Foundation, version 2.1.                                              *)
(*                                                                        *)
(*  It is distributed in the hope that it will be useful,                 *)
(*  but WITHOUT ANY WARRANTY; without even the implied warranty of        *)
(*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the          *)
(*  GNU Lesser General Public License for more details.                   *)
(*                                                                        *)
(*  See the GNU Lesser General Public License version 2.1                 *)
(*  for more details (enclosed in the file licenses/LGPLv2.1).            *)
(*                                                                        *)
(**************************************************************************)

open Base
module Format = Stdlib.Format
module Sys = Stdlib.Sys
module Filename = Stdlib.Filename
module Fun = Stdlib.Fun

(* In-memory representation of an NNet model, as built by [parse_in_channel]
   from the successive lines of the file. *)
type t = {
  (* Number of layers, first integer of the basic-information line. *)
  n_layers : int;
  (* Number of inputs, second integer of the basic-information line. *)
  n_inputs : int;
  (* Number of outputs, third integer of the basic-information line. *)
  n_outputs : int;
  (* Maximum layer size, fourth integer of the basic-information line. *)
  max_layer_size : int;
  (* Size of each layer, inputs and outputs included: the parser only accepts
     a line with [n_layers + 1] integers. *)
  layer_sizes : int list;
  (* Minimum value of each input ([n_inputs] floats). The parser only leaves
     this, and the three optional fields below, to [None] in permissive mode,
     when the corresponding line could not be read. *)
  min_input_values : float list option;
  (* Maximum value of each input ([n_inputs] floats). *)
  max_input_values : float list option;
  (* Mean value of each input, paired with one value for all outputs. *)
  mean_values : (float list * float) option;
  (* Range value of each input, paired with one value for all outputs. *)
  range_values : (float list * float) option;
  (* One list of floats per remaining record of the file, in file order, with
     no further interpretation. *)
  weights_biases : float list list;
}

(* NNet format handling. *)

(* Build the [Error] reported when the NNet format condition designated by
   [condition] (an ordinal such as "second") does not hold. *)
let nnet_format_error condition =
  let message =
    Format.sprintf "NNet format error: %s condition not satisfied." condition
  in
  Error message

(* Read the next CSV record from [in_channel] and convert each of its fields,
   once stripped of surrounding whitespace, with [f]. Fields on which [f]
   raises (whatever the exception) are dropped rather than reported, so the
   result may be shorter than the record. Exceptions raised by [Csv.next]
   itself, such as [End_of_file], are not caught here. *)
let handle_nnet_line ~f in_channel =
  let convert field =
    match f (String.strip field) with
    | value -> Some value
    | exception _ -> None
  in
  Csv.next in_channel |> List.filter_map ~f:convert

(* Skip the header part of the NNet format, ie the leading lines that start
   with "//". On success, [in_channel] is positioned at the beginning of the
   first line that is not a comment. Returns an [Error] mentioning [filename]
   when the end of file is reached before any such line is found. *)
let skip_nnet_header filename in_channel =
  (* Compiled once here instead of once per line read. *)
  let comment_prefix = Str.regexp "//" in
  let rec skip_comments () =
    (* Remember where the upcoming line starts, so that it can be un-read if
       it turns out not to be a comment. *)
    let line_start = Stdlib.pos_in in_channel in
    match Stdlib.input_line in_channel with
    | line when Str.string_match comment_prefix line 0 -> skip_comments ()
    | _first_model_line ->
      (* We have read one line past the header part: seek back. *)
      Stdlib.seek_in in_channel line_start;
      Ok ()
    | exception End_of_file ->
      Error (Format.sprintf "NNet model not found in file '%s'." filename)
  in
  skip_comments ()

(* Read the line giving, in this order, the number of layers, of inputs, of
   outputs, and the maximum layer size. Anything other than exactly four
   parseable integers, or a premature end of file, is a format error. *)
let handle_nnet_basic_info in_channel =
  let malformed () = nnet_format_error "second" in
  try
    match handle_nnet_line ~f:Int.of_string in_channel with
    | [ layers; inputs; outputs; largest_layer ] ->
      Ok (layers, inputs, outputs, largest_layer)
    | _ -> malformed ()
  with End_of_file -> malformed ()

(* Read the size of each layer, inputs and outputs included. The line must
   hold exactly [n_layers + 1] parseable integers; otherwise, or on a
   premature end of file, a format error is returned. *)
let handle_nnet_layer_sizes n_layers in_channel =
  let expected_count = n_layers + 1 in
  match handle_nnet_line ~f:Int.of_string in_channel with
  | sizes when List.length sizes = expected_count -> Ok sizes
  | _ -> nnet_format_error "third"
  | exception End_of_file -> nnet_format_error "third"

(* Skip the unused flag: the next CSV record is read and discarded, whatever
   its content. A premature end of file is a format error. *)
let handle_nnet_unused_flag in_channel =
  match Csv.next in_channel with
  (* The type constraint makes sure a whole record, and nothing else, is being
     thrown away. *)
  | (_ : string list) -> Ok ()
  | exception End_of_file -> nnet_format_error "fourth"

(* Retrieve the minimum values of inputs. The line must hold exactly
   [n_inputs] parseable floats; otherwise, or on a premature end of file, a
   format error is returned. *)
let handle_nnet_min_input_values n_inputs in_channel =
  match handle_nnet_line ~f:Float.of_string in_channel with
  | minima when List.length minima = n_inputs -> Ok minima
  | _ -> nnet_format_error "fifth"
  | exception End_of_file -> nnet_format_error "fifth"

(* Retrieve the maximum values of inputs. The line must hold exactly
   [n_inputs] parseable floats; otherwise, or on a premature end of file, a
   format error is returned. *)
let handle_nnet_max_input_values n_inputs in_channel =
  let malformed () = nnet_format_error "sixth" in
  match handle_nnet_line ~f:Float.of_string in_channel with
  | exception End_of_file -> malformed ()
  | maxima -> if List.length maxima = n_inputs then Ok maxima else malformed ()

(* Retrieve the mean values of inputs and one value for all outputs. The line
   must hold exactly [n_inputs + 1] parseable floats: the first [n_inputs] are
   returned as the input means, the last one as the output mean. Any other
   shape, or a premature end of file, is a format error. *)
let handle_nnet_mean_values n_inputs in_channel =
  let malformed () = nnet_format_error "seventh" in
  match handle_nnet_line ~f:Float.of_string in_channel with
  | exception End_of_file -> malformed ()
  | mean_values ->
    if List.length mean_values <> n_inputs + 1
    then malformed ()
    else (
      (* Destructure instead of calling a partial function on the remainder:
         with a negative [n_inputs] read from a malformed file the remainder
         can be empty, which must be a format error and not an exception. *)
      match List.split_n mean_values n_inputs with
      | mean_input_values, [ mean_output_value ] ->
        Ok (mean_input_values, mean_output_value)
      | _ -> malformed ())

(* Retrieve the range values of inputs and one value for all outputs. The line
   must hold exactly [n_inputs + 1] parseable floats: the first [n_inputs] are
   returned as the input ranges, the last one as the output range. Any other
   shape, or a premature end of file, is a format error. *)
let handle_nnet_range_values n_inputs in_channel =
  let malformed () = nnet_format_error "eighth" in
  match handle_nnet_line ~f:Float.of_string in_channel with
  | exception End_of_file -> malformed ()
  | range_values ->
    if List.length range_values <> n_inputs + 1
    then malformed ()
    else (
      (* Destructure instead of calling a partial function on the remainder:
         with a negative [n_inputs] read from a malformed file the remainder
         can be empty, which must be a format error and not an exception. *)
      match List.split_n range_values n_inputs with
      | range_input_values, [ range_output_value ] ->
        Ok (range_input_values, range_output_value)
      | _ -> malformed ())

(* Retrieve all layer weights and biases as they appear in the model: every
   remaining CSV record becomes one list of floats, in file order. Fields that
   cannot be converted to a float are dropped. No special treatment is
   performed. *)
let handle_nnet_weights_and_biases in_channel =
  let float_of_field field =
    match Float.of_string (String.strip field) with
    | value -> Some value
    | exception _ -> None
  in
  (* Rows are accumulated in reverse, hence the final [List.rev]. *)
  let prepend_row rows fields =
    List.filter_map fields ~f:float_of_field :: rows
  in
  Csv.fold_left ~f:prepend_row ~init:[] in_channel |> List.rev

(* Retrieve from [in_channel] the metadata and weights of an NNet model wrt
   the NNet format specification (see https://github.com/sisl/NNet for
   details). [filename] is handed to [skip_nnet_header] for its error message.
   When [permissive] is set, a failure on the min, max, mean or range line
   yields [None] for that field instead of stopping the parsing; failures on
   the other lines always stop it. The CSV channel built on [in_channel] is
   closed on success only. [Csv.Failure], [Sys_error] and [Failure] exceptions
   are turned into [Error]s. *)
let parse_in_channel ?(permissive = false) filename in_channel =
  let open Result in
  let ( let* ) = ( >>= ) in
  (* Wrap a successful section in [Some]; in permissive mode, a failed one
     becomes [None] instead of an error. *)
  let optional section =
    match section with
    | Ok value -> Ok (Some value)
    | Error _ as failure -> if permissive then Ok None else failure
  in
  try
    let* () = skip_nnet_header filename in_channel in
    let csv = Csv.of_channel in_channel in
    let* n_layers, n_inputs, n_outputs, max_layer_size =
      handle_nnet_basic_info csv
    in
    let* layer_sizes = handle_nnet_layer_sizes n_layers csv in
    let* () = handle_nnet_unused_flag csv in
    let* min_input_values =
      optional (handle_nnet_min_input_values n_inputs csv)
    in
    let* max_input_values =
      optional (handle_nnet_max_input_values n_inputs csv)
    in
    let* mean_values = optional (handle_nnet_mean_values n_inputs csv) in
    let* range_values = optional (handle_nnet_range_values n_inputs csv) in
    let weights_biases = handle_nnet_weights_and_biases csv in
    Csv.close_in csv;
    Ok
      {
        n_layers;
        n_inputs;
        n_outputs;
        max_layer_size;
        layer_sizes;
        min_input_values;
        max_input_values;
        mean_values;
        range_values;
        weights_biases;
      }
  with
  | Csv.Failure (_nrecord, _nfield, msg) -> Error msg
  | Sys_error s -> Error s
  | Failure msg -> Error (Format.sprintf "Unexpected error: %s" msg)

(* Open [filename] and parse it as an NNet model with [parse_in_channel],
   forwarding [permissive]. The channel is closed whether parsing succeeds,
   fails or raises. The file is opened before any error handling is set up,
   so an exception raised by [Stdlib.open_in] is not converted to [Error]. *)
let parse ?(permissive = false) filename =
  let in_channel = Stdlib.open_in filename in
  let release () = Stdlib.close_in in_channel in
  let run () = parse_in_channel ~permissive filename in_channel in
  Fun.protect ~finally:release run
OCaml

Innovation. Community. Security.