package prbnmcn-stats

  1. Overview
  2. Docs

Source file gen.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
module Make (RNG : sig
  type t

  val float : t -> float -> float

  val int : t -> int -> int

  val bool : t -> bool
end) =
struct
  (** Monad of samplers: a computation is a function taking the RNG state
      and producing a value. Every stage of a composed computation is handed
      the same state value. *)
  module M :
    Basic_intf.Monad
      with type 'a t = (RNG.t, 'a) Stats_intf.gen
       and type 'a res = (RNG.t, 'a) Stats_intf.gen = struct
    type 'a t = (RNG.t, 'a) Stats_intf.gen

    type 'a res = (RNG.t, 'a) Stats_intf.gen

    (* Run [m], then pass its outcome and the same state to [f]. *)
    let bind m f state = f (m state) state

    (* Run [m] and post-process its outcome with the pure function [f]. *)
    let map m f state = f (m state)

    (* Constant sampler: ignores the state entirely. *)
    let return x _state = x

    (* A computation already is its own result representation. *)
    let run m = m

    module Infix = struct
      let ( >>= ) = bind

      let ( >|= ) = map

      let ( let* ) = bind

      let return = return
    end
  end

  (* Re-export the monad operations of [M] directly in the functor result. *)
  include M

  (** The state threaded through samplers is the state type of [RNG]. *)
  type state = RNG.t

  (** [iid gen state] is the infinite sequence obtained by applying [gen] to
      [state] again and again. The sequence is built lazily with
      [Seq.unfold]; the same [state] value is passed at every step. *)
  let iid (gen : 'a t) state =
    let step rng = Some (gen rng, rng) in
    Seq.unfold step state

  (** [float bound] is [RNG.float] with the state moved to last position. *)
  let float bound = fun state -> RNG.float state bound

  (** [int bound] is [RNG.int] with the state moved to last position. *)
  let int bound = fun state -> RNG.int state bound

  (** [bool state] is [RNG.bool state]. *)
  let bool state = RNG.bool state

  (** [range { min; max } state] returns [min] plus a draw of
      [RNG.float state (max -. min)].

      @raise Invalid_argument if [max -. min >=. 0.] does not hold. *)
  let range { Stats_intf.min = lo; max = hi } state =
    let width = hi -. lo in
    if width >=. 0. then lo +. RNG.float state width
    else invalid_arg "uniform_in_interval"

  (* Unchecked Bernoulli draw: true when a draw of [RNG.float state 1.0] is
     strictly below [h]. *)
  let _bernoulli h state = RNG.float state 1.0 <. h

  (** [bernoulli h state] is [_bernoulli h state] after checking [h].

      @raise Invalid_argument if [h <. 0.0] or [h >. 1.0]. *)
  let bernoulli h state =
    if h <. 0.0 || h >. 1.0 then invalid_arg "bernoulli"
    else _bernoulli h state

  (** [geometric p state] counts how many consecutive [_bernoulli p] draws
      on [state] fail before the first one succeeds.

      @raise Invalid_argument if [p] is not in the interval (0, 1]; this
      includes NaN. *)
  let geometric p state =
    (* Validate with a positive test. The former check
       [p <=. 0.0 || p >. 1.0] lets NaN through (assuming the dotted
       operators follow IEEE comparison semantics), and a NaN probability
       never yields a success, so the loop below would never terminate. *)
    if not (p >. 0.0 && p <=. 1.0) then invalid_arg "geometric" ;
    let failures = ref 0 in
    while not (_bernoulli p state) do
      incr failures
    done ;
    !failures

  (** [uniform elts] is a sampler returning the element of [elts] at an
      index drawn with [int (Array.length elts)].

      @raise Invalid_argument if [elts] is empty; the check happens as soon
      as the array is supplied, before any state is given. *)
  let uniform (elts : 'a array) =
    match Array.length elts with
    | 0 -> invalid_arg "uniform"
    | size -> fun state -> elts.(int size state)

  (** [subsample ~n sampler] is a sampler that calls [sampler] repeatedly on
      the given state, counting the calls from 1, and returns the first draw
      whose count is a multiple of [n].

      @raise Division_by_zero if [n = 0] (raised after the first draw). *)
  let subsample ~n sampler : 'a t =
   fun rng_state ->
    let rec draw count =
      let sample = sampler rng_state in
      let count = count + 1 in
      if Int.equal (count mod n) 0 then sample else draw count
    in
    draw 0

  (** [of_empirical data] samples one element of [data] at an index drawn
      with [RNG.int].

      @raise Invalid_argument if [data] is empty; the check is performed on
      every draw. *)
  let of_empirical : 'a Stats_intf.emp -> 'a t =
   fun data rng_state ->
    match Array.length data with
    | 0 -> invalid_arg "of_empirical: length of data = 0"
    | size -> data.(RNG.int rng_state size)

  module Float = struct
    (** [exponential ~rate] draws [u] with [RNG.float _ 1.0] and returns
        [-log u / rate].

        @raise Invalid_argument if [rate <=. 0.0]; the check is performed on
        every draw. *)
    let exponential ~rate : float t =
     fun rng_state ->
      if rate <=. 0.0 then invalid_arg "exponential: rate <= 0"
      else
        let u = RNG.float rng_state 1.0 in
        let neg_log_u = ~-.(log u) in
        neg_log_u /. rate

    (** [box_muller ~mean ~std] returns a pair of variates: each one is
        [mean +. std *. z] where the two [z] come from the same accepted
        point of a rejection loop.

        @raise Invalid_argument if [std <=. 0.0]; the check is performed on
        every draw. *)
    let box_muller : mean:float -> std:float -> (float * float) t =
      (* Draw two coordinates, each [RNG.float _ 2.0 -. 1.0], and retry
         until the squared norm is neither zero nor at least one. *)
      let rec unit_pair rng_state =
        let a = RNG.float rng_state 2.0 -. 1.0 in
        let b = RNG.float rng_state 2.0 -. 1.0 in
        let norm2 = (a *. a) +. (b *. b) in
        if not (norm2 =. 0.0 || norm2 >=. 1.0) then
          let scale = sqrt (-2. *. log norm2 /. norm2) in
          (a *. scale, b *. scale)
        else unit_pair rng_state
      in
      fun ~mean ~std rng_state ->
        if std <=. 0.0 then invalid_arg "box_muller" ;
        let shift z = mean +. (std *. z) in
        let (z1, z2) = unit_pair rng_state in
        (shift z1, shift z2)

    (** Cache of a [gaussian] sampler: [Last x] holds the second variate of
        the most recent [box_muller] pair, [Fresh] means nothing is held. *)
    type gaussgen_state = Fresh | Last of float

    (** [gaussian ~mean ~std] builds a sampler with its own one-slot cache.
        [box_muller] produces variates in pairs, so every other call returns
        the cached second variate without touching the RNG state.

        @raise Invalid_argument if [std <=. 0.0]; raised when [~std] is
        supplied, before any state is given. *)
    let gaussian ~mean ~std : float t =
      if std <=. 0.0 then invalid_arg "gaussian: std <= 0" ;
      let cache = ref Fresh in
      let draw_pair = box_muller ~mean ~std in
      fun rng_state ->
        match !cache with
        | Last spare ->
            cache := Fresh ;
            spare
        | Fresh ->
            let (first, second) = draw_pair rng_state in
            cache := Last second ;
            first

    (** [poisson ~lambda] samples a Poisson variate of mean [lambda].

        Rates up to 700 use a single uniform draw and a sequential search
        through the cumulative distribution. Larger rates are split into
        equal pieces of at most 700 each and the independent piece samples
        are summed, which is again Poisson of mean [lambda]. Running time
        grows linearly with [lambda].

        @raise Invalid_argument if [lambda <= 0], or if [lambda] is NaN,
        infinite or larger than [max_int]. *)
    let poisson ~lambda : int t =
      (* [exp (-. rate)] underflows to zero for rates a little above 745,
         which would leave the cumulative sum stuck at zero forever. Staying
         at or below 700 keeps the initial mass a normal float. *)
      let max_rate = 700.0 in
      let invert rate rng_state =
        let u = RNG.float rng_state 1.0 in
        let rec search x p s =
          if u >. s then
            let x = x + 1 in
            let p = p *. rate /. float_of_int x in
            (* Once the mass has underflowed the cumulative sum can no
               longer grow: stop instead of looping forever. *)
            if p =. 0.0 then x else search x p (s +. p)
          else x
        in
        let p0 = exp ~-.rate in
        search 0 p0 p0
      in
      fun rng_state ->
        if lambda <=. 0.0 then invalid_arg "poisson: lambda <= 0" ;
        (* Also rejects NaN, for which the comparison is false. *)
        if not (lambda <=. float_of_int max_int) then
          invalid_arg "poisson: lambda is not finite or exceeds max_int" ;
        if lambda <=. max_rate then invert lambda rng_state
        else begin
          let pieces = int_of_float (ceil (lambda /. max_rate)) in
          let rate = lambda /. float_of_int pieces in
          let total = ref 0 in
          for _i = 1 to pieces do
            total := !total + invert rate rng_state
          done ;
          !total
        end

    (*
     * The following code for sampling Gamma RVs is due to the Owl project and is MIT licensed.
     *
     * https://github.com/owlbarn/owl/
     *
     * OWL - OCaml Scientific and Engineering Computing
     * Copyright (c) 2016-2020 Liang Wang <liang.wang@cl.cam.ac.uk>
     *)

    (* Unit-rate exponential draw: [-log1p (-u)] for [u] drawn with
       [RNG.float state 1.]. *)
    let std_exponential_rvs state =
      let u = RNG.float state 1. in
      let minus_u = -.u in
      -.log1p minus_u

    (** [std_gamma_rvs ~shape state] draws one Gamma variate with the given
        [shape] and unit scale. Three cases are handled:
        - [shape = 1]: a unit-rate exponential draw;
        - [shape < 1]: a rejection loop on a uniform and an exponential draw;
        - [shape > 1]: a rejection loop on Gaussian and uniform draws
          (NOTE(review): this looks like the Marsaglia-Tsang method; confirm
          against the Owl source the comment above refers to).

        The local exception [Found] is used only to leave the rejection
        loops once a candidate stored in [x] has been accepted. [shape] is
        not validated here. *)
    let std_gamma_rvs ~shape state =
      let exception Found in
      let x = ref infinity in
      (if shape =. 1. then x := std_exponential_rvs state
      else if shape <. 1. then
        try
          while true do
            let u = RNG.float state 1. in
            let v = std_exponential_rvs state in
            if u <=. 1. -. shape then (
              x := u ** (1. /. shape) ;
              if !x <=. v then raise Found)
            else
              let y = -.log ((1. -. u) /. shape) in
              x := (1. -. shape +. (shape *. y)) ** (1. /. shape) ;
              if !x <=. v +. y then raise Found
          done
        with
        (* Catch only the loop-exit signal. A catch-all handler here would
           swallow any exception raised by the RNG and return a rejected
           candidate as if it had been accepted. *)
        | Found -> ()
      else
        let b = shape -. (1. /. 3.) in
        let c = 1. /. sqrt (9. *. b) in
        try
          while true do
            let v = ref neg_infinity in
            while !v <=. 0. do
              (* Each call builds a fresh [gaussian] sampler, so the second
                 variate of every pair is dropped. *)
              x := gaussian ~mean:0.0 ~std:1.0 state ;
              v := 1. +. (c *. !x)
            done ;
            let v = !v *. !v *. !v in
            let u = RNG.float state 1. in
            (* Cheap acceptance test first, then the test involving logs. *)
            if u <. 1. -. (0.0331 *. !x *. !x *. !x *. !x) then (
              x := b *. v ;
              raise Found) ;
            if log u <. (0.5 *. !x *. !x) +. (b *. (1. -. v +. log v)) then (
              x := b *. v ;
              raise Found)
          done
        with Found -> ()) ;
      !x

    (** [gamma ~shape ~scale state] is a unit-scale Gamma draw of the given
        [shape], multiplied by [scale]. *)
    let gamma ~shape ~scale state =
      let unit_scale_draw = std_gamma_rvs ~shape state in
      scale *. unit_scale_draw

    (* Instance of [Alias.Make] over float masses, running in the sampler
       monad [M]; the last argument supplies the two random primitives the
       functor needs, both delegating to [RNG]. *)
    module Alias_f =
      Alias.Make (Basic_impl.Reals.Float) (Basic_impl.Reals.Float) (M)
        (struct
          type 'a t = 'a M.t

          let mass bound = fun rng -> RNG.float rng bound

          let int bound = fun rng -> RNG.int rng bound
        end)

    (** [categorical cases] builds an [Alias_f] structure from [cases] once
        and returns a sampler drawing from it on every call. *)
    let categorical cases =
      let table = Alias_f.create cases in
      fun rng -> Alias_f.sampler table rng

    (** [take_n n list acc] moves the first [n] elements of [list] onto
        [acc] (an accumulator in reverse order) and returns the accumulated
        prefix in original order together with the remaining suffix.

        @raise Invalid_argument if [list] runs out before [n] reaches 0. *)
    let rec take_n n list acc =
      match (n, list) with
      | (0, _) -> (List.rev acc, list)
      | (_, []) -> invalid_arg "take_n"
      | (_, x :: tl) -> take_n (n - 1) tl (x :: acc)

    (** [without_replacement n list rng_state] picks [n] elements of [list]
        by reservoir sampling and returns [(kept, discarded)].

        The first [n] elements seed the reservoir. The element at position
        [i >= n] draws a slot with [RNG.int rng_state (i + 1)]: a slot below
        [n] evicts the current occupant, otherwise the element itself is
        discarded. Discarded elements are listed most recent first.

        @raise Invalid_argument (from [take_n]) if [list] has fewer than
        [n] elements. *)
    let without_replacement n list rng_state =
      let (seed, remainder) = take_n n list [] in
      let kept = Array.of_list seed in
      let consider (position, discarded) candidate =
        let slot = RNG.int rng_state (position + 1) in
        let discarded =
          if slot < n then begin
            let evicted = kept.(slot) in
            kept.(slot) <- candidate ;
            evicted :: discarded
          end
          else candidate :: discarded
        in
        (position + 1, discarded)
      in
      let (_, discarded) = List.fold_left consider (n, []) remainder in
      (Array.to_list kept, discarded)
  end

  (* Make the float-based samplers defined in [Float] available directly in
     the functor result. *)
  include Float

  (** Categorical sampling with rational masses. *)
  module Rational = struct
    (* Instance of [Alias.Make] over [Basic_impl.Reals.Rational]. A random
       mass below [bound] is obtained by scaling [bound] with a draw of
       [RNG.float _ 1.0] converted through [Q.of_float]. *)
    module Alias_q =
      Alias.Make (Basic_impl.Reals.Rational) (Basic_impl.Reals.Rational) (M)
        (struct
          type 'a t = 'a M.t

          let mass bound rng =
            let fraction = Q.of_float (RNG.float rng 1.0) in
            Basic_impl.Reals.Rational.(bound * fraction)

          let int bound = fun rng -> RNG.int rng bound
        end)

    (** [categorical list] builds an [Alias_q] structure from [list] once
        and returns a sampler drawing from it on every call. *)
    let categorical list =
      let table = Alias_q.create list in
      fun rng -> Alias_q.sampler table rng
  end
end

(* Default instance: all samplers above, driven by the standard library's
   [Random.State]. *)
include Make (Random.State)
OCaml

Innovation. Community. Security.