Source file stochastic_estimator.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
open Searchspace
open Collections.Util
(** One recorded step of a random walk: which alternative was taken at a fork
    and how many alternatives that fork offered. [random_walk] only records
    a decision for forks that have at least two alternatives. *)
type decision = {
chosen: int;
(** Index of the alternative that was taken, as returned by the [rng]. *)
choices: int
(** Number of alternatives the fork offered. *)
}
(** Source of choices for [random_walk]: it is called with the number of
    alternatives at a fork and its result is used as a list index, so it must
    return a value in [0 .. bound - 1]. [Random.int] has the right shape. *)
type rng = int -> int
(* Binding operator so that search spaces can be composed with [let*].
   [bind] is not defined in this file; it comes from one of the opened modules
   (presumably [Searchspace] -- confirm). *)
let ( let* ) = bind
(** [random_walk rng space] follows a single root-to-leaf path through [space].

    At each fork with two or more alternatives, [rng] is called with the number
    of alternatives and the alternative at the returned index is followed; the
    decision is recorded. Forks with exactly one alternative are followed
    without consulting [rng] and without recording a decision. A fork with no
    alternatives is treated like a failure.

    @return [(path, result)] where [path] lists the recorded decisions from the
    root downwards and [result] is [Some x] if the walk ended in [Result x],
    or [None] if it ended in a failure.
    @raise Invalid_argument if [rng] returns a negative index (from [List.nth]).
    @raise Failure if [rng] returns an index past the end (from [List.nth]). *)
let rec random_walk rng space =
  match inspect space with
  | Fail -> ([], None)
  | Result x -> ([], Some x)
  (* Match on the list shape instead of comparing [List.length] with the
     physical-equality operator [==]. *)
  | Fork [] -> ([], None)
  | Fork [ only_choice ] -> random_walk rng only_choice
  | Fork choices ->
    let num_choices = List.length choices in
    let chosen = rng num_choices in
    let chosen_el = List.nth choices chosen in
    let recursed_path, result = random_walk rng chosen_el in
    ({ chosen; choices = num_choices } :: recursed_path, result)
(** Renders a decision as ["<chosen>/<choices>"], e.g. ["1/2"]. *)
let decision_to_string d =
  String.concat "/" [ string_of_int d.chosen; string_of_int d.choices ]
(** [result_to_string to_string result] renders the outcome of a walk:
    ["Found: "] followed by [to_string x] for [Some x], and ["Failed"] for
    [None]. *)
let result_to_string to_string result =
  Option.fold ~none:"Failed" ~some:(fun x -> "Found: " ^ to_string x) result
(** [walk_to_string to_string (path, result)] renders a walk as
    ["[d1, d2, ...] => outcome"]: the decisions are formatted with
    [decision_to_string] and joined with [", "]; the outcome is formatted with
    [result_to_string to_string]. *)
let walk_to_string to_string (path, result) =
  let rendered_path = with_separator decision_to_string ", " path in
  let rendered_result = result_to_string to_string result in
  String.concat "" [ "["; rendered_path; "] => "; rendered_result ]
(* Expect test for [random_walk]: prints one walk that always takes the first
   alternative, ten walks driven by [Random.int], and one walk that always
   takes the last alternative. *)
let%expect_test "random_walk" = begin
(* Fixed seed so that the "Random:" lines recorded below are reproducible. *)
Random.full_init [|0|];
(* Search space of all pairs drawn from two [int_range 1 5] ranges; every
   result is the string "n1 + n2 = sum". *)
let sums = (
let* n1 = int_range 1 5 in
let* n2 = int_range 1 5 in
let sum = n1+n2 in
return (Printf.sprintf "%d + %d = %d" n1 n2 sum)
) in
(* Runs a single walk with the given choice function and prints it. *)
let do_test rng =
let walk = random_walk rng sums in (
Printf.printf "%s\n" (walk_to_string Fun.id walk)
)
in (
Printf.printf "Always first: ";
do_test (fun _ -> 0);
for _i=1 to 10 do
Printf.printf("Random: ");
do_test Random.int
done;
Printf.printf "Always last: ";
(* [bound - 1] is the index of the last alternative at every fork. *)
do_test (fun bound -> bound-1)
)
; [%expect{|
Always first: [0/2, 0/2] => Found: 1 + 1 = 2
Random: [0/2, 0/2] => Found: 1 + 1 = 2
Random: [1/2, 0/2, 0/2] => Found: 2 + 1 = 3
Random: [0/2, 0/2] => Found: 1 + 1 = 2
Random: [1/2, 0/2, 1/2, 1/2, 1/2, 0/2] => Found: 2 + 4 = 6
Random: [0/2, 1/2, 1/2, 1/2, 1/2, 1/2] => Failed
Random: [0/2, 0/2] => Found: 1 + 1 = 2
Random: [1/2, 0/2, 0/2] => Found: 2 + 1 = 3
Random: [0/2, 1/2, 0/2] => Found: 1 + 2 = 3
Random: [0/2, 1/2, 1/2, 0/2] => Found: 1 + 3 = 4
Random: [0/2, 1/2, 1/2, 1/2, 0/2] => Found: 1 + 4 = 5
Always last: [1/2, 1/2, 1/2, 1/2, 1/2] => Failed |}]
end
(* Expect test that pretty-prints the whole decision tree of a small search
   space: pairs from [int_range 1 4] and [int_range 1 5] whose sum is a
   multiple of 7. The recorded tree contains exactly three solutions. *)
let%expect_test "decision tree for sums divisible by 7" =
let open Searchspace in
(* Printer for trees whose results are strings, writing to standard output. *)
let pp = pp_decision_tree (Format.pp_print_string) Format.std_formatter in
let sums_div7 =
let* n1 = int_range 1 4 in
let* n2 = int_range 1 5 in
let sum = n1 + n2 in
(* Keep the pair only when the sum is divisible by 7; otherwise the branch
   yields [empty]. *)
if sum mod 7 = 0 then return (Printf.sprintf "%d + %d = %d" n1 n2 sum)
else empty
in begin
sums_div7 |> pp
end;
[%expect{|
choices
choices
FAIL
choices
FAIL
choices
FAIL
choices
FAIL
choices
FAIL
FAIL
choices
choices
FAIL
choices
FAIL
choices
FAIL
choices
FAIL
choices
2 + 5 = 7
FAIL
choices
choices
FAIL
choices
FAIL
choices
FAIL
choices
3 + 4 = 7
choices
FAIL
FAIL
choices
choices
FAIL
choices
FAIL
choices
4 + 3 = 7
choices
FAIL
choices
FAIL
FAIL
FAIL
|}]
(** Exact node counts of a search tree, as computed by
    [calculate_true_values]. *)
type stats = {
nodes : int;
(** Total number of nodes: forks, failures and results. *)
forks : int;
(** Number of fork nodes. *)
fails : int;
(** Number of failure leaves. *)
solutions : int;
(** Number of result leaves. *)
}
(** [calculate_true_values space] visits every node of [space] and returns the
    exact counts. A result leaf counts as one node and one solution, a failure
    leaf as one node and one fail, and a fork as one node and one fork plus
    the totals of all of its alternatives. *)
let rec calculate_true_values space =
  match inspect space with
  | Result _ -> { nodes = 1; forks = 0; solutions = 1; fails = 0 }
  | Fail -> { nodes = 1; forks = 0; solutions = 0; fails = 1 }
  | Fork choices ->
    (* Add the counts of one alternative's subtree to the running totals. *)
    let add_subtree acc choice =
      let sub = calculate_true_values choice in
      {
        nodes = acc.nodes + sub.nodes;
        forks = acc.forks + sub.forks;
        solutions = acc.solutions + sub.solutions;
        fails = acc.fails + sub.fails;
      }
    in
    (* The initial totals account for the fork node itself. *)
    List.fold_left add_subtree
      { nodes = 1; forks = 1; solutions = 0; fails = 0 }
      choices
(** Small example search space: pairs [n1] from [int_range 1 4] and [n2] from
    [int_range 1 5] whose sum is divisible by 7. Each solution is the string
    ["n1 + n2 = sum"]; every other pair yields [empty]. *)
let sums_div7 =
  let* n1 = int_range 1 4 in
  let* n2 = int_range 1 5 in
  match (n1 + n2) mod 7 with
  | 0 -> return (Printf.sprintf "%d + %d = %d" n1 n2 (n1 + n2))
  | _ -> empty
(** A node of the partially materialized search tree kept by the estimator.
    Nodes are created by [create_node] and updated in place by [walk]. *)
type 'a node = {
node_view : 'a Searchspace.node_view;
(** What [inspect] returned for the search space at this node. *)
mutable children : 'a node option array;
(** One slot per alternative of a fork (empty for leaves); a slot stays
    [None] until [walk] first descends into that alternative. *)
mutable samples : int;
(** Number of times [walk] has visited this node. *)
mutable nodes_estimate : float;
(** Current estimate of the number of nodes in this subtree, including this
    node. *)
mutable fail_estimate : float;
(** Current estimate of the number of failure leaves in this subtree. *)
mutable solution_estimate : float;
(** Current estimate of the number of result leaves in this subtree. *)
}
(** [child_average children f] is the arithmetic mean of [f child] over the
    materialized ([Some]) entries of [children], taken in array order.
    Returns [1.0] when no entry is materialized. *)
let child_average (children : 'a node option array) (f : 'a node -> float) : float =
  let total, count =
    Array.fold_left
      (fun (total, count) slot ->
        match slot with
        | Some child -> (total +. f child, count + 1)
        | None -> (total, count))
      (0., 0) children
  in
  if count = 0 then 1.0 else total /. float_of_int count
(** [children_estimate children f] extrapolates the sum of [f] over all
    children: the number of slots in [children] times the average of [f] over
    the children materialized so far (see [child_average]). *)
let children_estimate (children : 'a node option array) (f : 'a node -> float) : float =
  let fan_out = float_of_int (Array.length children) in
  fan_out *. child_average children f
(** Number of alternatives offered by a node view: the length of the list for
    a fork, [0] for a failure or a result. *)
let num_choices = function
  | Fork choices -> List.length choices
  | Fail | Result _ -> 0
(** [create_node space] inspects [space] and builds an unvisited estimator
    node for it: zero samples and one empty child slot per alternative.

    The initial estimates are exact for leaves (one node, and one solution or
    one fail); a fork starts as one node with no fails or solutions until
    [walk] refines it. *)
let create_node (space : 'a Searchspace.t) : 'a node =
  let node_view = inspect space in
  (* Every fresh node counts itself once; only the leaf tallies differ. *)
  let fresh ~fails ~solutions =
    {
      node_view;
      children = Array.make (num_choices node_view) None;
      samples = 0;
      nodes_estimate = 1.0;
      fail_estimate = fails;
      solution_estimate = solutions;
    }
  in
  match node_view with
  | Result _ -> fresh ~fails:0.0 ~solutions:1.0
  | Fail -> fresh ~fails:1.0 ~solutions:0.0
  | Fork _ -> fresh ~fails:0.0 ~solutions:0.0
(** Strategy used by [walk] to decide which child slot of a fork to descend
    into; the result is used as an index into the node's [children] array. *)
type 'a child_selector = 'a node -> int
(** Selector that picks a child slot uniformly at random with [Random.int].

    For a node without child slots it returns [0], like the other selectors in
    this file, instead of letting [Random.int 0] raise [Invalid_argument].
    [walk] never calls a selector on such a node. *)
let uniform_selector node =
  match Array.length node.children with
  | 0 -> 0
  | n -> Random.int n
(** Sampling density of a child slot: the number of visits divided by the
    estimated number of leaves (fails plus solutions) below the child.
    An unmaterialized slot has rate [0.0]. If both estimates are [0.0] the
    quotient is a non-finite float. *)
let sample_rate slot =
  match slot with
  | None -> 0.0
  | Some { samples; fail_estimate; solution_estimate; _ } ->
    float_of_int samples /. (fail_estimate +. solution_estimate)
(** Selector that picks the child slot with the lowest [sample_rate], breaking
    ties (rates within [1e-8] of the minimum) uniformly at random.
    Unmaterialized slots have rate [0.0], so they are chosen before any
    visited child is revisited. Returns [0] for a node without child slots.

    A rate exactly equal to the minimum always counts as a candidate. This
    matters when every rate is [infinity] (all children visited but with zero
    estimated leaves): [infinity -. infinity] is [nan], which would otherwise
    leave no candidate and make [Random.int 0] raise. *)
let undersampled_selector (node : 'a node) : int =
  let n = Array.length node.children in
  if n = 0 then 0
  else
    let rates = Array.map sample_rate node.children in
    let min_rate = Array.fold_left min rates.(0) rates in
    let is_minimal rate =
      Float.equal rate min_rate || abs_float (rate -. min_rate) < 1e-8
    in
    (* Candidates stay in ascending index order and exactly one [Random.int]
       draw is made, as before, so seeded runs are unchanged. *)
    let candidates =
      List.init n Fun.id |> List.filter (fun i -> is_minimal rates.(i))
    in
    List.nth candidates (Random.int (List.length candidates))
(** Selector that picks a child slot at random with probability proportional
    to a weight. A materialized child weighs its [nodes_estimate], but at
    least [1.0]; an unmaterialized slot weighs the average [nodes_estimate] of
    the materialized children, or [1.0] if there are none.
    Returns [0] for a node without child slots. *)
let weighted_selector (node : 'a node) : int =
  let slots = node.children in
  let n = Array.length slots in
  if n = 0 then 0
  else begin
    (* Sum and count of the estimates of the children seen so far. *)
    let sum, count =
      Array.fold_left
        (fun (sum, count) slot ->
          match slot with
          | Some child -> (sum +. child.nodes_estimate, count + 1)
          | None -> (sum, count))
        (0., 0) slots
    in
    let fallback = if count = 0 then 1.0 else sum /. float_of_int count in
    let weights =
      Array.map
        (function
          | Some child -> max child.nodes_estimate 1.0
          | None -> fallback)
        slots
    in
    let total = Array.fold_left ( +. ) 0.0 weights in
    let r = Random.float total in
    (* First index whose cumulative weight reaches [r]; the last index is the
       fallback if rounding keeps the running sum below [r]. *)
    let rec pick i acc =
      if i >= n then n - 1
      else
        let reached = acc +. weights.(i) in
        if reached >= r then i else pick (i + 1) reached
    in
    pick 0 0.0
  end
(** [walk select_child node] performs one sampling pass from [node] down to a
    leaf and refreshes the estimates on the way back up.

    The visit is counted in [node.samples]. For a fork with at least one child
    slot, [select_child] picks a slot, the child is created on first use, and
    the walk recurses into it. Afterwards the fork's estimates are recomputed
    from its children with [children_estimate] (plus one for the fork itself
    in [nodes_estimate]). Leaves and forks without slots are left as they are. *)
let rec walk select_child (node : 'a node) : unit =
  node.samples <- succ node.samples;
  match node.node_view with
  | Fail | Result _ -> ()
  | Fork choices ->
    if Array.length node.children = 0 then ()
    else begin
      let chosen = select_child node in
      let child_node =
        match node.children.(chosen) with
        | Some existing -> existing
        | None ->
          (* First visit of this alternative: materialize and remember it. *)
          let fresh = create_node (List.nth choices chosen) in
          node.children.(chosen) <- Some fresh;
          fresh
      in
      walk select_child child_node;
      let slots = node.children in
      node.nodes_estimate <-
        1. +. children_estimate slots (fun child -> child.nodes_estimate);
      node.fail_estimate <-
        children_estimate slots (fun child -> child.fail_estimate);
      node.solution_estimate <-
        children_estimate slots (fun child -> child.solution_estimate)
    end
(** Snapshot of the estimator's current figures for a whole search space,
    read from the root node. *)
type estimates = {
nodes : float;
(** Estimated total number of nodes. *)
fails : float;
(** Estimated number of failure leaves. *)
solutions : float;
(** Estimated number of result leaves. *)
materialized_nodes : int;
(** Number of nodes actually created so far, as counted by
    [count_materialized_nodes]. *)
}
(** Counts the estimator nodes that have been created in the subtree rooted
    at [node], including [node] itself. Empty child slots count as zero. *)
let rec count_materialized_nodes (node : 'a node) : int =
  match node.node_view with
  | Fail | Result _ -> 1
  | Fork _ ->
    let below =
      Array.fold_left
        (fun total slot ->
          total + Option.fold ~none:0 ~some:count_materialized_nodes slot)
        0 node.children
    in
    1 + below
(** [estimate ?selector n_trials space] builds a fresh estimator tree for
    [space], runs [n_trials] sampling walks from its root using [selector]
    (default: [undersampled_selector]), and returns the root's estimates
    together with the number of nodes that were materialized.
    No walk is performed when [n_trials] is less than 1. *)
let estimate ?(selector=undersampled_selector) n_trials (space : 'a Searchspace.t) : estimates =
  let root = create_node space in
  let rec run remaining =
    if remaining >= 1 then begin
      walk selector root;
      run (remaining - 1)
    end
  in
  run n_trials;
  let materialized_nodes = count_materialized_nodes root in
  {
    nodes = root.nodes_estimate;
    fails = root.fail_estimate;
    solutions = root.solution_estimate;
    materialized_nodes;
  }
(* Expect test comparing the exact counts of [sums_div7] with the estimator's
   figures after 1000 walks using the default selector. In the recorded output
   all 49 nodes are materialized and the rounded estimates equal the exact
   counts. *)
let%expect_test "estimate number of nodes" =
let true_values = calculate_true_values sums_div7 in
Printf.printf "True values\n";
Printf.printf " number of nodes: %d\n" true_values.nodes;
Printf.printf " number of fails: %d\n" true_values.fails;
Printf.printf " number of solutions: %d\n" true_values.solutions;
Printf.printf "\n";
let estimates = estimate 1000 sums_div7 in
Printf.printf "Estimated\n";
Printf.printf " materialized nodes: %d\n" estimates.materialized_nodes;
(* Adding 0.5 before truncating rounds the float estimates to the nearest
   integer for printing. *)
Printf.printf " number of nodes: %d\n" (int_of_float (estimates.nodes +. 0.5));
Printf.printf " number of fails: %d\n" (int_of_float (estimates.fails +. 0.5));
Printf.printf " number of solutions: %d\n" (int_of_float (estimates.solutions +. 0.5));
[%expect{|
True values
number of nodes: 49
number of fails: 22
number of solutions: 3
Estimated
materialized nodes: 49
number of nodes: 49
number of fails: 22
number of solutions: 3
|}]
(** [balanced_range start stop] is a search space yielding the integers from
    [start] to [stop] inclusive, built by splitting the range at its midpoint
    and joining the two halves with [++]. It is [empty] when [start > stop].

    Ranges of one and two elements are handled directly; the two-element case
    must not go through the midpoint split, because for negative bounds the
    division rounds towards zero and the left half would be the same range. *)
let rec balanced_range start stop =
  if start > stop then empty
  else
    match stop - start with
    | 0 -> return start
    | 1 -> return start ++ return stop
    | _ ->
      let mid = (start + stop) / 2 in
      balanced_range start mid ++ balanced_range (mid + 1) stop
(* Expect test running the default estimator on a larger search space built
   from balanced ranges: pairs from 1..100 x 1..100 whose sum is divisible by
   7. Five independent estimates are taken with 1000 to 5000 walks and printed
   next to the exact counts.
   NOTE(review): no seed is set in this test, so the recorded figures depend
   on the state of the global [Random] generator when it runs -- confirm the
   test is deterministic in the test runner. *)
let%expect_test "undersampling larger balanced searchspace" =
(* Shadow [int_range] locally so the space below is built with
   [balanced_range]. *)
let int_range = balanced_range in
(* NOTE(review): the name says "right heavy", but this variant uses the
   balanced ranges bound above. *)
let right_heavy_space = (
let* n1 = int_range 1 100 in
let* n2 = int_range 1 100 in
let sum = return (n1 + n2) in
(* [|?>] is not defined in this file; presumably it filters the space by the
   predicate -- confirm against [Searchspace]. *)
sum |?> (fun x -> x mod 7 = 0)
) in
let true_values = calculate_true_values right_heavy_space in
Printf.printf "True values\n";
Printf.printf " number of nodes: %d\n" true_values.nodes;
Printf.printf " number of fails: %d\n" true_values.fails;
Printf.printf " number of solutions: %d\n" true_values.solutions;
Printf.printf "\n";
for samplers = 1 to 5 do
let samples = 1000 * samplers in
Printf.printf "Sample run %d:\n" samples;
(* Each run starts from a fresh estimator tree. *)
let estimates = estimate samples right_heavy_space in
Printf.printf "Estimated values balanced trees:\n";
Printf.printf " materialized nodes: %d\n" estimates.materialized_nodes;
Printf.printf " number of nodes: %d\n" (int_of_float (estimates.nodes +. 0.5));
Printf.printf " number of fails: %d\n" (int_of_float (estimates.fails +. 0.5));
Printf.printf " number of solutions: %d\n" (int_of_float (estimates.solutions +. 0.5));
Printf.printf "\n";
done;
[%expect{|
True values
number of nodes: 19999
number of fails: 8572
number of solutions: 1428
Sample run 1000:
Estimated values balanced trees:
materialized nodes: 5143
number of nodes: 19007
number of fails: 8228
number of solutions: 1276
Sample run 2000:
Estimated values balanced trees:
materialized nodes: 8301
number of nodes: 19187
number of fails: 8278
number of solutions: 1316
Sample run 3000:
Estimated values balanced trees:
materialized nodes: 10568
number of nodes: 18661
number of fails: 8013
number of solutions: 1318
Sample run 4000:
Estimated values balanced trees:
materialized nodes: 12268
number of nodes: 18439
number of fails: 7984
number of solutions: 1236
Sample run 5000:
Estimated values balanced trees:
materialized nodes: 13598
number of nodes: 17291
number of fails: 7448
number of solutions: 1198
|}]
(* Expect test running the default estimator on the same 1..100 x 1..100
   "sum divisible by 7" space, this time built with the ordinary [int_range].
   Five independent estimates are taken with 1000 to 5000 walks.
   The recorded output shows the estimates staying far below the exact counts
   here (e.g. 10199 estimated nodes against 20201 after 5000 walks).
   NOTE(review): no seed is set in this test, so the recorded figures depend
   on the state of the global [Random] generator when it runs -- confirm the
   test is deterministic in the test runner. *)
let%expect_test "undersampling larger unbalanced searchspace" =
let right_heavy_space = (
let* n1 = int_range 1 100 in
let* n2 = int_range 1 100 in
let sum = return (n1 + n2) in
(* [|?>] is not defined in this file; presumably it filters the space by the
   predicate -- confirm against [Searchspace]. *)
sum |?> (fun x -> x mod 7 = 0)
) in
let true_values = calculate_true_values right_heavy_space in
Printf.printf "True values\n";
Printf.printf " number of nodes: %d\n" true_values.nodes;
Printf.printf " number of fails: %d\n" true_values.fails;
Printf.printf " number of solutions: %d\n" true_values.solutions;
Printf.printf "\n";
for samplers = 1 to 5 do
let samples = 1000 * samplers in
Printf.printf "Sample run %d:\n" samples;
(* Each run starts from a fresh estimator tree. *)
let estimates = estimate samples right_heavy_space in
Printf.printf "Estimated values (unbalanced trees):\n";
Printf.printf " materialized nodes: %d\n" estimates.materialized_nodes;
Printf.printf " number of nodes: %d\n" (int_of_float (estimates.nodes +. 0.5));
Printf.printf " number of fails: %d\n" (int_of_float (estimates.fails +. 0.5));
Printf.printf " number of solutions: %d\n" (int_of_float (estimates.solutions +. 0.5));
Printf.printf "\n";
done;
[%expect{|
True values
number of nodes: 20201
number of fails: 8673
number of solutions: 1428
Sample run 1000:
Estimated values (unbalanced trees):
materialized nodes: 2099
number of nodes: 2199
number of fails: 946
number of solutions: 154
Sample run 2000:
Estimated values (unbalanced trees):
materialized nodes: 4100
number of nodes: 4203
number of fails: 1798
number of solutions: 304
Sample run 3000:
Estimated values (unbalanced trees):
materialized nodes: 6100
number of nodes: 6203
number of fails: 2651
number of solutions: 451
Sample run 4000:
Estimated values (unbalanced trees):
materialized nodes: 8099
number of nodes: 8199
number of fails: 3514
number of solutions: 586
Sample run 5000:
Estimated values (unbalanced trees):
materialized nodes: 10099
number of nodes: 10199
number of fails: 4378
number of solutions: 722
|}]
(** Incremental estimator: a partially materialized search tree together with
    the selector used to refine it. Build one with [create], refine it with
    [sample], and read the current figures with [estimates]. *)
type 'a t = {
root : 'a node;
(** Root node of the estimator tree. *)
selector : 'a child_selector;
(** Strategy passed to [walk] on every sampling pass. *)
}
(** [create ?selector space] builds an incremental estimator for [space] with
    an unvisited root node. [selector] defaults to [undersampled_selector]. *)
let create ?(selector=undersampled_selector) (space : 'a Searchspace.t) : 'a t =
  let root = create_node space in
  { root; selector }
(** [sample n est] runs [n] further sampling walks from the root of [est],
    updating its estimates in place. Does nothing when [n] is less than 1. *)
let sample n (est : 'a t) : unit =
  let rec loop remaining =
    if remaining >= 1 then begin
      walk est.selector est.root;
      loop (remaining - 1)
    end
  in
  loop n
(** [estimates est] returns the current figures of [est]: the root node's
    estimates and the number of nodes materialized so far. It does not run
    any sampling walk. *)
let estimates (est : 'a t) : estimates =
  let root = est.root in
  let materialized_nodes = count_materialized_nodes root in
  {
    nodes = root.nodes_estimate;
    fails = root.fail_estimate;
    solutions = root.solution_estimate;
    materialized_nodes;
  }
(* Expect test for the incremental API ([create] / [sample] / [estimates]) on
   the 1..100 x 1..100 "sum divisible by 7" space built with [int_range].
   A single estimator is refined in five batches of 1000 walks, and its
   figures are printed after each batch next to the exact counts.
   NOTE(review): no seed is set in this test, so the recorded figures depend
   on the state of the global [Random] generator when it runs -- confirm the
   test is deterministic in the test runner. *)
let%expect_test "incremental estimator API on unbalanced searchspace" =
let right_heavy_space = (
let* n1 = int_range 1 100 in
let* n2 = int_range 1 100 in
let sum = return (n1 + n2) in
(* [|?>] is not defined in this file; presumably it filters the space by the
   predicate -- confirm against [Searchspace]. *)
sum |?> (fun x -> x mod 7 = 0)
) in
let true_values = calculate_true_values right_heavy_space in
Printf.printf "True values\n";
Printf.printf " number of nodes: %d\n" true_values.nodes;
Printf.printf " number of fails: %d\n" true_values.fails;
Printf.printf " number of solutions: %d\n" true_values.solutions;
Printf.printf "\n";
let est = create right_heavy_space in
for samplers = 1 to 5 do
(* [samples] is the cumulative number of walks after this batch; it is only
   used for the heading. *)
let samples = 1000 * samplers in
(* The same estimator is reused, so each iteration adds 1000 walks. *)
sample 1000 est;
Printf.printf "Sample run %d:\n" samples;
let estimates = estimates est in
Printf.printf "Estimated values (incremental):\n";
Printf.printf " materialized nodes: %d\n" estimates.materialized_nodes;
Printf.printf " number of nodes: %d\n" (int_of_float (estimates.nodes +. 0.5));
Printf.printf " number of fails: %d\n" (int_of_float (estimates.fails +. 0.5));
Printf.printf " number of solutions: %d\n" (int_of_float (estimates.solutions +. 0.5));
Printf.printf "\n";
done;
[%expect{|
True values
number of nodes: 20201
number of fails: 8673
number of solutions: 1428
Sample run 1000:
Estimated values (incremental):
materialized nodes: 2099
number of nodes: 2199
number of fails: 946
number of solutions: 154
Sample run 2000:
Estimated values (incremental):
materialized nodes: 4101
number of nodes: 4211
number of fails: 1798
number of solutions: 308
Sample run 3000:
Estimated values (incremental):
materialized nodes: 6099
number of nodes: 6199
number of fails: 2665
number of solutions: 435
Sample run 4000:
Estimated values (incremental):
materialized nodes: 8099
number of nodes: 8199
number of fails: 3515
number of solutions: 585
Sample run 5000:
Estimated values (incremental):
materialized nodes: 10099
number of nodes: 10199
number of fails: 4372
number of solutions: 728
|}]