package smtml

  1. Overview
  2. Docs
Legend:
Page
Library
Module
Module type
Parameter
Class
Class type
Source

Source file regression_model.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
(* SPDX-License-Identifier: MIT *)
(* Copyright (C) 2023-2024 formalsec *)
(* Written by Hichem Rami Ait El Hara *)

open Yojson.Safe.Util

(** String-keyed maps, extended with a lookup that falls back to [0]. *)
module FeatMap = struct
  include Map.Make (String)

  (** [find_def0 k m] is the value bound to [k] in [m], or [0] when [k] has
      no binding. *)
  let find_def0 k m = Option.value (find_opt k m) ~default:0
end
(* TODO: use ints or an ADT instead of strings for the keys, even though
         strings are convenient in practice. *)

(** Feature values: a map from feature names (strings) to integers. *)
type features = int FeatMap.t

(** Numeric score; an alias of [float]. *)
type score = float

(** [pp_float_aux fmt f] prints [f] using the [%.17g] conversion, i.e. with 17
    significant digits. When [f] is integral and the result is in plain
    decimal form, a trailing [.] is appended (e.g. [3.]).

    Integral values are deliberately not printed with the default 6-digit
    [%g]: that would round [1234567.] to [1.23457e+06], and appending [.] to
    an exponent form does not give a float literal. *)
let pp_float_aux fmt f =
  let repr = Printf.sprintf "%.17g" f in
  (* The conversion drops the decimal point of integral values; put it back
     unless the exponent form was selected. *)
  if Float.is_integer f && not (String.contains repr 'e') then
    Fmt.pf fmt "%s." repr
  else Fmt.string fmt repr

(** [pp_float fmt f] prints [f] with [pp_float_aux], wrapped in parentheses
    when [f] has its sign bit set (negative values, including [-0.]) or
    compares below [0.] under [Float.compare] (which also covers [nan]).

    [Float.compare] alone treats [-0.] as equal to [0.], which would leave a
    leading minus sign unparenthesised. NOTE(review): the parentheses
    presumably exist so the text can follow [score_of_float] as an argument,
    as emitted by [pp_score] -- confirm. *)
let pp_float fmt f =
  let needs_parens = Float.sign_bit f || Float.compare f 0. < 0 in
  if needs_parens then Fmt.pf fmt "(%a)" pp_float_aux f
  else pp_float_aux fmt f

(** [pp_score ppf s] prints [s] as [(score_of_float <float>)], where the float
    text is produced by [pp_float]. *)
let pp_score ppf (s : score) = Fmt.pf ppf "(score_of_float %a)" pp_float s

(** [compare_score a b] orders two scores exactly as [Float.compare] does. *)
let compare_score a b = Float.compare a b

(** [score_of_int n] converts the integer [n] to a score. *)
let score_of_int n = Float.of_int n

(** [score_of_float x] is [x] itself, typed as a [score]. *)
let score_of_float : float -> score = fun x -> x

(** [to_score json] extracts a score from a JSON number.

    Uses [to_number] rather than [to_float] so that a number encoded as a JSON
    integer (e.g. [5] instead of [5.0]) is accepted, in line with
    [tree_of_json], which accepts integer leaf values. Non-numeric JSON
    values still raise the Yojson [Type_error] exception. *)
let to_score = to_number

(** Binary decision tree. A [Leaf] carries a score. A [Node] names a feature
    and a threshold and has two subtrees: [eval_tree] descends into [left]
    when the feature value compares less than or equal to [threshold], and
    into [right] otherwise. *)
type tree =
  | Leaf of score
  | Node of
      { feature : string
      ; threshold : score
      ; left : tree
      ; right : tree
      }

(** Gradient-boosted model: an initial value plus a list of trees. [predict]
    returns [init_value] plus the sum of the evaluations of all [trees]. *)
type gb_model =
  { init_value : score
  ; trees : tree list
  }

(** Decision-tree model: a single [tree]. *)
type dt_model = tree

(** A regression model: either gradient-boosted ([GBModel]) or a single
    decision tree ([DTModel]). *)
type t =
  | GBModel of gb_model
  | DTModel of dt_model

(** [pp_tree fmt tree] prints [tree] using constructor and record syntax:
    leaves as [Leaf (...)] and inner nodes as a [Node] record, with scores
    printed by [pp_score] and the feature name printed with [%S]. *)
let rec pp_tree fmt tree =
  match tree with
  | Node n ->
    Fmt.pf fmt "Node { feature = %S; threshold = %a; left = %a; right = %a }"
      n.feature pp_score n.threshold pp_tree n.left pp_tree n.right
  | Leaf score -> Fmt.pf fmt "Leaf (%a)" pp_score score

(** [pp_gb_model fmt model] prints [model] in record syntax; the trees are
    printed with [pp_tree] and separated by ["; "]. *)
let pp_gb_model fmt (model : gb_model) =
  let sep ppf () = Fmt.pf ppf "; " in
  Fmt.pf fmt "{ init_value = %a; trees = [%a] }" pp_score model.init_value
    (Fmt.list ~sep pp_tree) model.trees

(** [pp fmt model] prints [model] as its constructor applied to the printed
    payload ([pp_tree] for [DTModel], [pp_gb_model] for [GBModel]). *)
let pp fmt model =
  match model with
  | DTModel tree -> Fmt.pf fmt "DTModel (%a)" pp_tree tree
  | GBModel boosted -> Fmt.pf fmt "GBModel (%a)" pp_gb_model boosted

(** [tree_of_json json] decodes a tree from a JSON object.

    The ["value"] member decides the shape: a float or an integer yields a
    [Leaf]; [`Null] yields a [Node] built from the ["feature"], ["threshold"],
    ["left"] and ["right"] members, the last two being decoded recursively.
    Any other ["value"] fails with ["Invalid tree structure in JSON"]. The
    Yojson accessors used here raise on members of an unexpected type. *)
let rec tree_of_json json =
  let subtree key = json |> member key |> tree_of_json in
  match json |> member "value" with
  | `Null ->
    Node
      { feature = json |> member "feature" |> to_string
      ; threshold = json |> member "threshold" |> to_score
      ; left = subtree "left"
      ; right = subtree "right"
      }
  | `Int i -> Leaf (score_of_int i)
  | `Float f -> Leaf f
  | _ -> Fmt.failwith "Invalid tree structure in JSON"

(** [model_of_json json] decodes a model from a JSON object.

    The boolean ["gradient_boost"] member selects the variant. When it is
    true, the ["model"] member provides ["init_value"] and the ["trees"] list
    of a [GBModel]; otherwise the ["model"] member is decoded as the single
    tree of a [DTModel]. No ["n_estimators"] member is read. *)
let model_of_json json =
  let gradient_boosted = json |> member "gradient_boost" |> to_bool in
  let payload = json |> member "model" in
  match gradient_boosted with
  | false -> DTModel (tree_of_json payload)
  | true ->
    GBModel
      { init_value = payload |> member "init_value" |> to_score
      ; trees = payload |> member "trees" |> to_list |> List.map tree_of_json
      }

(** [read_models_from_file filename] parses the JSON file [filename], whose
    top level must be an object, and returns one [(name, model)] pair per
    member, each member value being decoded with [model_of_json]. *)
let read_models_from_file filename =
  let decode_entry (solver_name, solver_json) =
    (solver_name, model_of_json solver_json)
  in
  Yojson.Safe.from_file filename |> to_assoc |> List.map decode_entry

(** [eval_tree feats tree] walks [tree] down to a leaf and returns its score.

    At each node the value of the node's feature is looked up in [feats]
    (features without a binding count as [0]) and converted to a score; the
    walk continues in [left] when that value is less than or equal to the
    threshold under [compare_score], and in [right] otherwise. *)
let rec eval_tree (feats : int FeatMap.t) tree =
  match tree with
  | Leaf score -> score
  | Node node ->
    let observed = FeatMap.find_def0 node.feature feats |> score_of_int in
    let branch =
      if compare_score observed node.threshold > 0 then node.right
      else node.left
    in
    eval_tree feats branch

(** [choose_best scores] returns the second component of the pair whose score
    is the smallest under [compare_score]; among equal scores the earliest
    pair in [scores] wins.

    Fails with an assertion on lists of fewer than two elements.
    NOTE(review): a single-element list is rejected as well -- confirm that
    this is intended. *)
let choose_best scores =
  match scores with
  | [] | [ _ ] -> assert false
  | first :: rest ->
    (* Only a strictly lower score replaces the current best, so ties keep
       the earliest candidate. *)
    let keep_lower ((best_score, _) as best) ((score, _) as candidate) =
      if compare_score score best_score < 0 then candidate else best
    in
    snd (List.fold_left keep_lower first rest)

(** [predict feats model] evaluates [model] on the features [feats].

    A [DTModel] is the evaluation of its tree. A [GBModel] is its
    [init_value] plus the sum, accumulated from [0.0] in list order, of the
    evaluations of its trees. *)
let predict (feats : int FeatMap.t) model =
  match model with
  | GBModel { init_value; trees } ->
    let boost =
      List.fold_left (fun total tree -> total +. eval_tree feats tree) 0.0 trees
    in
    init_value +. boost
  | DTModel tree -> eval_tree feats tree