package caisar
A platform for characterizing the safety and robustness of artificial-intelligence-based software
Install
dune-project
Dependency
Authors
Maintainers
Sources
caisar-0.2.1.tbz
sha256=a9a704f1e4e255eee2e9b0333e6c7b0e3e002293ce0068faa1c3d7c18d209997
sha512=7e35bd5527f82c5c6f62452c88e2971907a4eab89fd4efb699b99eb95f730d752908d51c47e104dcff5ceb58cf24c87d3399cb42e09a47691440927463168abb
doc/src/caisar.nnet/nnet.ml.html
Source file nnet.ml
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200
(**************************************************************************)
(*                                                                        *)
(*  This file is part of CAISAR.                                          *)
(*                                                                        *)
(*  Copyright (C) 2023                                                    *)
(*    CEA (Commissariat à l'énergie atomique et aux énergies              *)
(*         alternatives)                                                  *)
(*                                                                        *)
(*  You can redistribute it and/or modify it under the terms of the GNU   *)
(*  Lesser General Public License as published by the Free Software       *)
(*  Foundation, version 2.1.                                              *)
(*                                                                        *)
(*  It is distributed in the hope that it will be useful,                 *)
(*  but WITHOUT ANY WARRANTY; without even the implied warranty of        *)
(*  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the         *)
(*  GNU Lesser General Public License for more details.                   *)
(*                                                                        *)
(*  See the GNU Lesser General Public License version 2.1                 *)
(*  for more details (enclosed in the file licenses/LGPLv2.1).            *)
(*                                                                        *)
(**************************************************************************)

open Base
module Format = Stdlib.Format
module Sys = Stdlib.Sys
module Filename = Stdlib.Filename
module Fun = Stdlib.Fun

(** A neural network read from a file in the NNet format
    (see https://github.com/sisl/NNet). *)
type t = {
  n_layers : int;  (** Number of layers, as declared by the model. *)
  n_inputs : int;  (** Number of network inputs. *)
  n_outputs : int;  (** Number of network outputs. *)
  max_layer_size : int;  (** Maximum layer size, as declared by the model. *)
  layer_sizes : int list;
      (** Size of each layer, inputs and outputs included; always holds
          [n_layers + 1] elements. *)
  min_input_values : float list option;
      (** Minimum value of each input; [None] when that line was malformed
          and parsing was permissive. *)
  max_input_values : float list option;
      (** Maximum value of each input; [None] under the same conditions as
          [min_input_values]. *)
  mean_values : (float list * float) option;
      (** Mean of each input, plus one mean shared by all outputs. *)
  range_values : (float list * float) option;
      (** Range of each input, plus one range shared by all outputs. *)
  weights_biases : float list list;
      (** Every remaining line of the file, converted to floats, in file
          order; no further structure is imposed. *)
}

(* NNet format handling. *)

(* Error reported when a line of the NNet preamble does not have the expected
   shape; [s] names the violated condition (an ordinal such as "third"). *)
let nnet_format_error s =
  Error (Format.sprintf "NNet format error: %s condition not satisfied." s)

(* Convert each CSV field with [f] after stripping surrounding whitespace.
   Fields that [f] rejects (e.g. the empty string left by a trailing
   separator) are silently dropped: any exception raised by [f] means "not a
   number" here, whatever its constructor. *)
let parse_fields ~f fields =
  List.filter_map fields ~f:(fun field ->
    try Some (f (String.strip field)) with _ -> None)

(* Parse a single NNet format line: read the next CSV record and convert each
   field into a number by means of converter [f].
   @raise End_of_file when no record is left. *)
let handle_nnet_line ~f in_channel = parse_fields ~f (Csv.next in_channel)

(* Read a line that must hold exactly [expected] numbers. A wrong count, as
   well as a premature end of input, is reported as the [ordinal] format
   condition. *)
let handle_nnet_fixed_line ~f ~expected ~ordinal in_channel =
  match handle_nnet_line ~f in_channel with
  | exception End_of_file -> nnet_format_error ordinal
  | values ->
    if List.length values = expected
    then Ok values
    else nnet_format_error ordinal

(* Read a line of [n_inputs] per-input values followed by one value shared by
   all outputs, and return them split apart. Mismatches are reported as the
   [ordinal] format condition. *)
let handle_nnet_inputs_and_output ~ordinal n_inputs in_channel =
  match handle_nnet_line ~f:Float.of_string in_channel with
  | exception End_of_file -> nnet_format_error ordinal
  | values -> (
    if List.length values <> n_inputs + 1
    then nnet_format_error ordinal
    else
      (* Pattern-match the tail instead of [List.hd_exn]: a nonsensical
         (negative) [n_inputs] then yields an error rather than an
         exception. *)
      match List.split_n values n_inputs with
      | input_values, [ output_value ] -> Ok (input_values, output_value)
      | _ -> nnet_format_error ordinal)

(* Skip the header part, i.e. the leading comment lines starting with "//",
   of the NNet format. On success the channel is positioned at the start of
   the first non-comment line. Fails if the file has no such line. *)
let skip_nnet_header filename in_channel =
  (* [pos] is the offset of the start of the line about to be read. *)
  let rec skip_from pos =
    match Stdlib.input_line in_channel with
    | exception End_of_file ->
      Error (Format.sprintf "NNet model not found in file '%s'." filename)
    | line ->
      if String.is_prefix line ~prefix:"//"
      then skip_from (Stdlib.pos_in in_channel)
      else (
        (* We have read one line past the header part: seek back. *)
        Stdlib.seek_in in_channel pos;
        Ok ())
  in
  skip_from (Stdlib.pos_in in_channel)

(* Retrieve number of layers, inputs, outputs and maximum layer size. *)
let handle_nnet_basic_info in_channel =
  match handle_nnet_line ~f:Int.of_string in_channel with
  | [ n_layers; n_inputs; n_outputs; max_layer_size ] ->
    Ok (n_layers, n_inputs, n_outputs, max_layer_size)
  | _ -> nnet_format_error "second"
  | exception End_of_file -> nnet_format_error "second"

(* Retrieve size of each layer, including inputs and outputs. *)
let handle_nnet_layer_sizes n_layers in_channel =
  handle_nnet_fixed_line ~f:Int.of_string ~expected:(n_layers + 1)
    ~ordinal:"third" in_channel

(* Skip the unused flag line; only its presence is checked. *)
let handle_nnet_unused_flag in_channel =
  match Csv.next in_channel with
  | (_ : string list) -> Ok ()
  | exception End_of_file -> nnet_format_error "forth"

(* Retrieve minimum values of inputs. *)
let handle_nnet_min_input_values n_inputs in_channel =
  handle_nnet_fixed_line ~f:Float.of_string ~expected:n_inputs
    ~ordinal:"fifth" in_channel

(* Retrieve maximum values of inputs. *)
let handle_nnet_max_input_values n_inputs in_channel =
  handle_nnet_fixed_line ~f:Float.of_string ~expected:n_inputs
    ~ordinal:"sixth" in_channel

(* Retrieve mean values of inputs and one value for all outputs. *)
let handle_nnet_mean_values n_inputs in_channel =
  handle_nnet_inputs_and_output ~ordinal:"seventh" n_inputs in_channel

(* Retrieve range values of inputs and one value for all outputs. *)
let handle_nnet_range_values n_inputs in_channel =
  handle_nnet_inputs_and_output ~ordinal:"eighth" n_inputs in_channel

(* Retrieve all remaining lines (layer weights and biases) as they appear in
   the model, in file order. No special treatment is performed. *)
let handle_nnet_weights_and_biases in_channel =
  Csv.fold_left in_channel ~init:[] ~f:(fun rows record ->
    parse_fields ~f:Float.of_string record :: rows)
  |> List.rev

(* Retrieve the NNet model metadata and weights of [filename], read from
   [in_channel], wrt the NNet format specification (see
   https://github.com/sisl/NNet for details).

   [filename] is only used in error messages. When [permissive] is [true], a
   malformed min/max/mean/range line yields [None] for that field instead of
   an error. The CSV reader is closed with [Csv.close_in] on success only.
   CSV errors, [Sys_error] and [Failure] are turned into [Error] messages. *)
let parse_in_channel ?(permissive = false) filename in_channel =
  let open Result in
  let ok_opt r =
    match r with
    | Ok x -> Ok (Some x)
    | Error _ as error -> if permissive then Ok None else error
  in
  try
    skip_nnet_header filename in_channel >>= fun () ->
    let csv = Csv.of_channel in_channel in
    handle_nnet_basic_info csv >>= fun (n_ls, n_is, n_os, max_l_size) ->
    handle_nnet_layer_sizes n_ls csv >>= fun layer_sizes ->
    handle_nnet_unused_flag csv >>= fun () ->
    ok_opt (handle_nnet_min_input_values n_is csv) >>= fun min_input_values ->
    ok_opt (handle_nnet_max_input_values n_is csv) >>= fun max_input_values ->
    ok_opt (handle_nnet_mean_values n_is csv) >>= fun mean_values ->
    ok_opt (handle_nnet_range_values n_is csv) >>= fun range_values ->
    let weights_biases = handle_nnet_weights_and_biases csv in
    Csv.close_in csv;
    Ok
      {
        n_layers = n_ls;
        n_inputs = n_is;
        n_outputs = n_os;
        max_layer_size = max_l_size;
        layer_sizes;
        min_input_values;
        max_input_values;
        mean_values;
        range_values;
        weights_biases;
      }
  with
  | Csv.Failure (_nrecord, _nfield, msg) -> Error msg
  | Sys_error s -> Error s
  | Failure msg -> Error (Format.sprintf "Unexpected error: %s" msg)

(* Parse the NNet model stored in file [filename]; see [parse_in_channel].
   The file is always closed before returning. Failing to open the file is
   reported as [Error] (with the system message) rather than by letting
   [Sys_error] escape, consistently with the rest of the parser. *)
let parse ?(permissive = false) filename =
  match Stdlib.open_in filename with
  | exception Sys_error msg -> Error msg
  | in_channel ->
    Fun.protect
      ~finally:(fun () -> Stdlib.close_in in_channel)
      (fun () -> parse_in_channel ~permissive filename in_channel)
sectionYPositions = computeSectionYPositions($el), 10)"
x-init="setTimeout(() => sectionYPositions = computeSectionYPositions($el), 10)"
>