package saga


Source file ngram.ml

(* Smoothing strategy: additive [`Add_k k] or unnormalised [`Stupid_backoff
   alpha]. *)
type smoothing = [ `Add_k of float | `Stupid_backoff of float ]

(* Summary statistics reported by [stats]. *)
type stats = { vocab_size : int; total_tokens : int; unique_ngrams : int }

(* Per-order count tables. [orders.(i)] maps a context of length [i] to a
   table of next-token counts; [order_totals.(i)] holds the total number of
   continuations seen for each such context. *)
type counts = {
  orders : (int array, (int, int) Hashtbl.t) Hashtbl.t array;
  order_totals : (int array, int) Hashtbl.t array;
}

type t = {
  order : int;
  smoothing : smoothing;
  mutable vocab_size : int;
  mutable total_tokens : int;
  counts : counts;
}

let ensure_order order =
  if order < 1 then invalid_arg "Ngram.empty: order must be >= 1"

let make_counts order =
  {
    orders = Array.init order (fun _ -> Hashtbl.create 1024);
    order_totals = Array.init order (fun _ -> Hashtbl.create 1024);
  }

(* An untrained model; [smoothing] defaults to add-k with k = 0.01. *)
let empty ~order ?(smoothing = `Add_k 0.01) () =
  ensure_order order;
  {
    order;
    smoothing;
    vocab_size = 0;
    total_tokens = 0;
    counts = make_counts order;
  }

let smoothing t = t.smoothing
let order t = t.order
let is_trained t = t.total_tokens > 0

(* The vocabulary is the dense range [0, vocab_size): keep [vocab_size] at
   the largest token id seen plus one. Assumes non-negative token ids. *)
let update_vocab_size t tokens =
  Array.iter
    (fun token -> if token + 1 > t.vocab_size then t.vocab_size <- token + 1)
    tokens

(* Record every k-gram (k = 1 .. t.order) starting at each position: its
   first k - 1 tokens are the context and its last token the continuation. *)
let add_to_counts t tokens =
  let len = Array.length tokens in
  t.total_tokens <- t.total_tokens + len;
  update_vocab_size t tokens;
  for i = 0 to len - 1 do
    for k = 1 to t.order do
      if i + k - 1 < len then (
        let ctx_len = k - 1 in
        let context =
          if ctx_len = 0 then [||] else Array.sub tokens i ctx_len
        in
        let next = tokens.(i + k - 1) in
        let order_idx = k - 1 in
        let orders_tbl = t.counts.orders.(order_idx) in
        let next_tbl =
          match Hashtbl.find_opt orders_tbl context with
          | Some tbl -> tbl
          | None ->
              let tbl = Hashtbl.create 8 in
              Hashtbl.add orders_tbl context tbl;
              tbl
        in
        let current =
          match Hashtbl.find_opt next_tbl next with Some c -> c | None -> 0
        in
        Hashtbl.replace next_tbl next (current + 1);
        let totals_tbl = t.counts.order_totals.(order_idx) in
        let prev_total =
          match Hashtbl.find_opt totals_tbl context with
          | Some v -> v
          | None -> 0
        in
        Hashtbl.replace totals_tbl context (prev_total + 1))
    done
  done
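
(* Worked example (illustrative, not from the source): with
   [tokens = [| 1; 2; 3 |]] and [t.order = 2], the loops record the
   unigrams 1, 2 and 3 under the empty context and the bigram
   continuations [| 1 |] -> 2 and [| 2 |] -> 3; [vocab_size] becomes 4
   (largest id 3, plus one) and [total_tokens] becomes 3. *)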

(* Accumulate one token sequence into the model; returns [t] so calls can
   be chained. *)
let add_sequence t tokens =
  add_to_counts t tokens;
  t

(* Build a model of the given order from a list of token sequences. *)
let of_sequences ~order ?smoothing sequences =
  let model = empty ~order ?smoothing () in
  List.iter (fun seq -> ignore (add_sequence model seq)) sequences;
  model

(* [unique_ngrams] counts distinct (context, next-token) pairs across all
   orders, not just the highest. *)
let stats t =
  let unique =
    Array.fold_left
      (fun acc tbl ->
        Hashtbl.fold
          (fun _ next_map acc -> acc + Hashtbl.length next_map)
          tbl acc)
      0 t.counts.orders
  in
  {
    vocab_size = t.vocab_size;
    total_tokens = t.total_tokens;
    unique_ngrams = unique;
  }

(* Clip a caller-supplied context to its last [order - 1] tokens, the most
   the model can condition on. *)
let normalise_context t context =
  if t.order = 1 then [||]
  else
    let len = Array.length context in
    let need = min (t.order - 1) len in
    if need = len then Array.copy context
    else Array.sub context (len - need) need
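
(* For instance (illustrative): with [t.order = 3], the context
   [| 5; 6; 7 |] is clipped to its two-token suffix [| 6; 7 |]. *)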

(* Stupid-backoff score for [token] after [context]: the relative frequency
   at the longest matching order, else [alpha] times the score with a
   shorter context; with no context left, the uniform 1 / vocab_size. *)
let rec backoff_score t context token alpha order_idx =
  if order_idx < 0 then
    if t.vocab_size = 0 then 0.0 else 1.0 /. float_of_int t.vocab_size
  else
    let orders_tbl = t.counts.orders.(order_idx) in
    let totals_tbl = t.counts.order_totals.(order_idx) in
    let ctx_len = Array.length context in
    let used_ctx =
      if ctx_len = order_idx then context
      else if order_idx = 0 then [||]
      else
        let start = max 0 (ctx_len - order_idx) in
        Array.sub context start order_idx
    in
    let next_counts =
      match Hashtbl.find_opt orders_tbl used_ctx with
      | Some tbl -> tbl
      | None -> Hashtbl.create 0
    in
    let total =
      match Hashtbl.find_opt totals_tbl used_ctx with
      | Some value -> float_of_int value
      | None -> 0.0
    in
    let c =
      match Hashtbl.find_opt next_counts token with
      | Some value -> float_of_int value
      | None -> 0.0
    in
    if c > 0.0 && total > 0.0 then c /. total
    else alpha *. backoff_score t context token alpha (order_idx - 1)
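
(* Restating the recursion above as the usual stupid-backoff recurrence:
     S(w | c) = count(c, w) / count(c)      if count(c, w) > 0
              = alpha * S(w | shorter c)    otherwise
   bottoming out at the uniform 1 / vocab_size. The scores are not
   renormalised, so they need not sum to 1 over the vocabulary. *)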

(* Log scores for every token in [0, vocab_size) given [context]. Add-k
   yields proper log probabilities; stupid backoff yields unnormalised log
   scores. *)
let logits t ~context =
  if not (is_trained t) then invalid_arg "Ngram.logits: model not trained";
  let context = normalise_context t context in
  match t.smoothing with
  | `Add_k k ->
      let vocab = max 1 t.vocab_size in
      let logits = Array.make vocab Float.neg_infinity in
      let orders_tbl = t.counts.orders.(t.order - 1) in
      let totals_tbl = t.counts.order_totals.(t.order - 1) in
      let total =
        match Hashtbl.find_opt totals_tbl context with
        | Some value -> float_of_int value +. (k *. float_of_int vocab)
        | None -> k *. float_of_int vocab
      in
      let counts_tbl =
        match Hashtbl.find_opt orders_tbl context with
        | Some tbl -> tbl
        | None -> Hashtbl.create 0
      in
      for token = 0 to vocab - 1 do
        let count =
          match Hashtbl.find_opt counts_tbl token with
          | Some c -> float_of_int c
          | None -> 0.0
        in
        logits.(token) <- log ((count +. k) /. total)
      done;
      logits
  | `Stupid_backoff alpha ->
      let vocab = max 1 t.vocab_size in
      Array.init vocab (fun token ->
          let prob = backoff_score t context token alpha (t.order - 1) in
          if prob <= 0.0 then Float.neg_infinity else log prob)
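
(* The [`Add_k k] branch above computes the standard additive estimate
     P(w | c) = (count(c, w) + k) / (count(c) + k * V),  V = vocab_size.
   Illustrative numbers (not from the source): after training on
   [| 1; 2; 3 |] with order 2 and k = 0.01, V = 4, so
   P(2 | [| 1 |]) = (1 + 0.01) / (1 + 0.04) ~ 0.971. *)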

(* Sum of per-position log scores, starting at position [order - 1], the
   first position with a full-length context. Tokens outside
   [0, vocab_size) are skipped. *)
let log_prob t tokens =
  if not (is_trained t) then invalid_arg "Ngram.log_prob: model not trained";
  let sum = ref 0.0 in
  let len = Array.length tokens in
  for i = t.order - 1 to len - 1 do
    let context =
      if t.order = 1 then [||]
      else
        Array.sub tokens (max 0 (i - t.order + 1)) (min (t.order - 1) (i + 1))
    in
    let logits = logits t ~context in
    let token = tokens.(i) in
    if token >= 0 && token < Array.length logits then
      sum := !sum +. logits.(token)
  done;
  !sum

let perplexity t tokens =
  let len = Array.length tokens in
  if len = 0 then infinity
  else
    let lp = log_prob t tokens in
    let denom = float_of_int (max 1 (len - (t.order - 1))) in
    exp (-.lp /. denom)
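
(* Restating the code: perplexity = exp (-L / N), where L is
   [log_prob t tokens] and N = len - (order - 1) is the number of scored
   positions, clamped to at least 1. Lower is better. *)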

(* Persistence via OCaml's [Marshal]: files are only readable by binaries
   with a compatible representation of [t], and [load] performs no type
   checking, so loading untrusted files is unsafe. *)
let save t path =
  let oc = open_out_bin path in
  Fun.protect
    ~finally:(fun () -> close_out oc)
    (fun () -> Marshal.to_channel oc t [])

let load path =
  let ic = open_in_bin path in
  Fun.protect
    ~finally:(fun () -> close_in ic)
    (fun () ->
      let model : t = Marshal.from_channel ic in
      model)
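
(* Usage sketch (illustrative, not part of this file): assumes the module
   is exposed as [Ngram], as the error messages above suggest.

   let () =
     let model =
       Ngram.of_sequences ~order:2 ~smoothing:(`Add_k 0.5)
         [ [| 0; 1; 2 |]; [| 0; 1; 3 |] ]
     in
     let s = Ngram.stats model in
     Printf.printf "vocab=%d tokens=%d unique=%d ppl=%.3f\n"
       s.Ngram.vocab_size s.Ngram.total_tokens s.Ngram.unique_ngrams
       (Ngram.perplexity model [| 0; 1; 2 |]) *)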