package knights_tour

  1. Overview
  2. Docs

Source file stochastic_estimator.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
open Searchspace
open Collections.Util

(** One branching decision recorded by [random_walk]. *)
type decision = {
    chosen: int;   (* Zero-based index of the alternative that was taken *)
    choices: int   (* Number of alternatives that were available at that fork *)
}

(** Shape of a random number source as [random_walk] uses it: called with the
    number of alternatives at a fork, its result is used as a zero-based
    index into them. [Random.int] has this shape. *)
type rng = int -> int

(* Binding operator: lets search spaces be composed with [let*] syntax.
   It is a plain alias for [bind]. *)
let ( let* ) = bind

(** [random_walk rng space] follows a single path from the root of [space]
    down to a leaf.

    At a fork with two or more alternatives, [rng] is called with the number
    of alternatives and must return the zero-based index of the one to follow
    (an out-of-range index makes [List.nth] raise). Forks with exactly one
    alternative are followed without consulting [rng] and without being
    recorded.

    Returns [(path, result)] where [path] lists the recorded decisions from
    the root downwards and [result] is [Some x] if the walk ended in
    [Result x], or [None] if it ended in [Fail] or in a fork without
    alternatives. *)
let rec random_walk rng space =
	match inspect space with
	| Fail -> ([], None)
	| Result x -> ([], Some x)
	| Fork [] -> ([], None)
	| Fork [only_choice] -> random_walk rng only_choice
	| Fork choices ->
		let num_choices = List.length choices in
		let chosen = rng num_choices in
		let (recursed_path, result) = random_walk rng (List.nth choices chosen) in
		({chosen; choices = num_choices} :: recursed_path, result)

(** Renders a decision as [chosen/choices], e.g. a decision that took
    alternative 1 out of 2 becomes [1/2]. *)
let decision_to_string (d : decision) =
	Printf.sprintf "%d/%d" d.chosen d.choices

(** Renders the outcome of a walk: [None] becomes the text Failed, and
    [Some x] becomes the text Found: followed by a space and [to_string x]. *)
let result_to_string to_string outcome =
	match outcome with
	| None -> "Failed"
	| Some value -> Printf.sprintf "Found: %s" (to_string value)

(** Renders the pair returned by [random_walk]: the decisions of [path],
    comma separated inside square brackets, then an arrow, then the outcome
    rendered by [result_to_string to_string]. *)
let walk_to_string to_string (path, result) =
	let rendered_path = with_separator decision_to_string ", " path in
	let rendered_result = result_to_string to_string result in
	Printf.sprintf "[%s] => %s" rendered_path rendered_result

(* Exercises [random_walk] on the space of all sums n1 + n2 with both operands
   drawn from [int_range 1 5], using three index sources: always the first
   alternative, ten walks driven by [Random.int] (after seeding with
   [Random.full_init [|0|]]), and always the last alternative.
   The expected output depends on the exact sequence of random draws. *)
let%expect_test "random_walk" = begin
	Random.full_init [|0|];
	let sums = (
		let* n1 = int_range 1 5 in 
		let* n2 = int_range 1 5 in
		let sum = n1+n2 in 
			(* if sum mod 2==0 then *)
				return (Printf.sprintf "%d + %d = %d" n1 n2 sum)
			(* else
				empty *)
	) in
	let do_test rng =
		let walk = random_walk rng sums in (
			Printf.printf "%s\n" (walk_to_string Fun.id walk)
		) 
	in (
		Printf.printf "Always first: ";
		do_test (fun _ -> 0);
		for _i=1 to 10 do
			Printf.printf("Random: ");
			do_test Random.int
		done;
		Printf.printf "Always last: ";
		do_test (fun bound -> bound-1)
	)
	; [%expect{|
		Always first: [0/2, 0/2] => Found: 1 + 1 = 2
		Random: [0/2, 0/2] => Found: 1 + 1 = 2
		Random: [1/2, 0/2, 0/2] => Found: 2 + 1 = 3
		Random: [0/2, 0/2] => Found: 1 + 1 = 2
		Random: [1/2, 0/2, 1/2, 1/2, 1/2, 0/2] => Found: 2 + 4 = 6
		Random: [0/2, 1/2, 1/2, 1/2, 1/2, 1/2] => Failed
		Random: [0/2, 0/2] => Found: 1 + 1 = 2
		Random: [1/2, 0/2, 0/2] => Found: 2 + 1 = 3
		Random: [0/2, 1/2, 0/2] => Found: 1 + 2 = 3
		Random: [0/2, 1/2, 1/2, 0/2] => Found: 1 + 3 = 4
		Random: [0/2, 1/2, 1/2, 1/2, 0/2] => Found: 1 + 4 = 5
		Always last: [1/2, 1/2, 1/2, 1/2, 1/2] => Failed |}]
end

(* Pretty-prints the complete decision tree of the space of pairs
   (n1 from [int_range 1 4], n2 from [int_range 1 5]) whose sum is divisible
   by 7, so that the nesting of choice, failure and result nodes is visible. *)
let%expect_test "decision tree for sums divisible by 7" =
		let open Searchspace in
		let pp = pp_decision_tree (Format.pp_print_string) Format.std_formatter in
		let sums_div7 =
			let* n1 = int_range 1 4 in
			let* n2 = int_range 1 5 in
			let sum = n1 + n2 in
			if sum mod 7 = 0 then return (Printf.sprintf "%d + %d = %d" n1 n2 sum)
			else empty
		in begin
			sums_div7 |> pp
		end;
	[%expect{|
   choices
     choices
       FAIL
       choices
         FAIL
         choices
           FAIL
           choices
             FAIL
             choices
               FAIL
               FAIL
     choices
       choices
         FAIL
         choices
           FAIL
           choices
             FAIL
             choices
               FAIL
               choices
                 2 + 5 = 7
                 FAIL
       choices
         choices
           FAIL
           choices
             FAIL
             choices
               FAIL
               choices
                 3 + 4 = 7
                 choices
                   FAIL
                   FAIL
         choices
           choices
             FAIL
             choices
               FAIL
               choices
                 4 + 3 = 7
                 choices
                   FAIL
                   choices
                     FAIL
                     FAIL
           FAIL
   |}]

(** Exact node counts of a search space, as produced by
    [calculate_true_values]. *)
type stats = {
    nodes : int;       (* Total number of nodes: forks, failures and results *)
    forks : int;       (* Number of [Fork] nodes *)
    fails : int;       (* Number of [Fail] leaves *)
    solutions : int;   (* Number of [Result] leaves *)
}

(** [calculate_true_values space] exhaustively traverses [space] and returns
    exact counts: every [Fail] and [Result] counts as one node (and one
    failure or one solution respectively); every [Fork] counts as one node and
    one fork, plus the totals of all its alternatives. *)
let rec calculate_true_values space =
	match inspect space with
	| Fail -> {nodes = 1; forks = 0; fails = 1; solutions = 0}
	| Result _ -> {nodes = 1; forks = 0; fails = 0; solutions = 1}
	| Fork choices ->
		(* Start from the counts of the fork itself and add each subtree in turn. *)
		let add_subtree totals choice =
			let sub = calculate_true_values choice in
			{
				nodes = totals.nodes + sub.nodes;
				forks = totals.forks + sub.forks;
				fails = totals.fails + sub.fails;
				solutions = totals.solutions + sub.solutions;
			}
		in
		List.fold_left add_subtree {nodes = 1; forks = 1; fails = 0; solutions = 0} choices

(** Search space of all pairs (n1 from [int_range 1 4], n2 from
    [int_range 1 5]) whose sum is divisible by 7; each solution is the
    equation rendered as text. Pairs with any other sum yield [empty]. *)
let sums_div7 =
	let describe a b = Printf.sprintf "%d + %d = %d" a b (a + b) in
	let* n1 = int_range 1 4 in
	let* n2 = int_range 1 5 in
	match (n1 + n2) mod 7 with
	| 0 -> return (describe n1 n2)
	| _ -> empty

(** Mutable sampling-tree node mirroring one node of the search space.
    Nodes are created by [create_node] and updated in place by [walk]. *)
type 'a node = {
		node_view : 'a Searchspace.node_view;           (* Cached result of [inspect] on the corresponding search space *)
		mutable children : 'a node option array;        (* One slot per alternative, indexed by decision number; [None] until that alternative is first visited *)
		mutable samples : int;                          (* Number of walks that have passed through this node *)
		mutable nodes_estimate : float;                 (* Current estimate of the number of nodes in this subtree, this node included *)
		mutable fail_estimate : float;                  (* Current estimate of the number of failures in this subtree; recomputed on each walk through a fork *)
		mutable solution_estimate : float;              (* Current estimate of the number of solutions in this subtree; recomputed on each walk through a fork *)
}


(** [child_average children f] is the arithmetic mean of [f child] over the
    materialized ([Some]) slots of [children], summed in index order.
    If no slot is materialized the result is [1.0]. *)
let child_average (children : 'a node option array) (f : 'a node -> float) : float =
	(* Single pass carrying the running sum and the number of materialized slots. *)
	let accumulate (sum, count) slot =
		match slot with
		| None -> (sum, count)
		| Some child -> (sum +. f child, count + 1)
	in
	match Array.fold_left accumulate (0., 0) children with
	| (_, 0) -> 1.0
	| (sum, count) -> sum /. float_of_int count

(** [children_estimate children f] extrapolates the total of [f] over all
    slots of [children]: the number of slots multiplied by the average of [f]
    over the materialized ones (see [child_average]). *)
let children_estimate (children : 'a node option array) (f : 'a node -> float) : float =
	let slot_count = float_of_int (Array.length children) in
	let per_child = child_average children f in
	slot_count *. per_child

(** Number of alternatives of a [Fork] view; [0] for [Fail] and [Result]. *)
let num_choices = function
	| Fork alternatives -> List.length alternatives
	| Fail | Result _ -> 0

(** [create_node space] inspects [space] once and builds a fresh, unsampled
    node for it, with one empty child slot per alternative.

    Initial estimates: every node counts itself as one node; a [Fail] starts
    with one failure, a [Result] with one solution, and a [Fork] with neither
    (its estimates are filled in by sampling). *)
let create_node (space : 'a Searchspace.t) : 'a node =
	let view = inspect space in
	let (fail_estimate, solution_estimate) =
		match view with
		| Fork _ -> (0.0, 0.0)
		| Fail -> (1.0, 0.0)
		| Result _ -> (0.0, 1.0)
	in
	let children = Array.make (num_choices view) None in
	{
		node_view = view;
		children;
		samples = 0;
		nodes_estimate = 1.0;
		fail_estimate;
		solution_estimate;
	}

(** A selection strategy: given a node, returns the index of the child slot
    that [walk] should descend into. *)
type 'a child_selector = 'a node -> int

(** Selects a child slot uniformly at random.

    For a node without child slots it returns [0], like the other selectors
    in this file, instead of letting [Random.int 0] raise
    [Invalid_argument]. *)
let uniform_selector node =
	match Array.length node.children with
	| 0 -> 0
	| n -> Random.int n

(** Sampling density of a child slot: the number of walks through the child
    divided by its currently estimated number of leaves (failures plus
    solutions). An unmaterialized slot has rate [0.0]. The division is a
    float division, so a zero leaf estimate does not raise. *)
let sample_rate slot =
	match slot with
	| None -> 0.0
	| Some { samples; fail_estimate; solution_estimate; _ } ->
		let estimated_leaves = fail_estimate +. solution_estimate in
		float_of_int samples /. estimated_leaves

(** Selects, uniformly at random, one of the child slots whose [sample_rate]
    is minimal (within a tolerance of 1e-8), i.e. one of the least densely
    sampled children. Returns [0] for a node without child slots.

    A slot always qualifies when its rate equals the minimum exactly. This
    matters when every rate is [infinity]: the difference of two infinities is
    [nan], so the tolerance test alone would reject every slot and
    [Random.int] would be called with [0]. *)
let undersampled_selector (node : 'a node) : int =
	let n = Array.length node.children in
	if n = 0 then 0
	else
		let rates = Array.map sample_rate node.children in
		let min_rate = Array.fold_left min rates.(0) rates in
		let is_minimal rate =
			rate = min_rate || abs_float (rate -. min_rate) < 1e-8
		in
		let candidates = List.filter (fun i -> is_minimal rates.(i)) (List.init n Fun.id) in
		List.nth candidates (Random.int (List.length candidates))


(** Selects a child slot at random with probability proportional to its
    weight. A materialized child weighs its [nodes_estimate] (at least [1.0]);
    an unmaterialized slot weighs the average [nodes_estimate] of the
    materialized children, as given by [child_average] (which is [1.0] when
    none is materialized). Returns [0] for a node without child slots. *)
let weighted_selector (node : 'a node) : int =
	let n = Array.length node.children in
	if n = 0 then 0
	else
		(* Reuse the shared helper rather than re-deriving the average inline. *)
		let avg = child_average node.children (fun c -> c.nodes_estimate) in
		let weights = Array.map (function
			| Some child -> max child.nodes_estimate 1.0
			| None -> avg
		) node.children in
		let total = Array.fold_left ( +. ) 0.0 weights in
		let r = Random.float total in
		(* Walk the cumulative weights until they reach the drawn value;
		   fall back to the last slot if rounding leaves [r] uncovered. *)
		let rec pick i acc =
			if i >= n then n - 1
			else if acc +. weights.(i) >= r then i
			else pick (i+1) (acc +. weights.(i))
		in pick 0 0.0

(** [walk select_child node] performs one sampling walk starting at [node].

    The node's sample counter is incremented. At a [Fail] or [Result] the walk
    ends. At a [Fork] with at least one child slot, [select_child] picks a
    slot, the child is created on first visit, the walk continues into it, and
    afterwards this node's three estimates are recomputed from its children
    via [children_estimate] (the node estimate additionally counts this node
    itself). *)
let rec walk select_child (node : 'a node) : unit =
	node.samples <- succ node.samples;
	match node.node_view with
	| Fail | Result _ -> ()
	| Fork choices ->
		if Array.length node.children > 0 then begin
			let index = select_child node in
			let child =
				match node.children.(index) with
				| Some existing -> existing
				| None ->
					(* First visit: materialize the child and remember it. *)
					let fresh = create_node (List.nth choices index) in
					node.children.(index) <- Some fresh;
					fresh
			in
			walk select_child child;
			(* Propagate the refreshed child estimates upwards. *)
			let extrapolate f = children_estimate node.children f in
			node.nodes_estimate <- 1. +. extrapolate (fun c -> c.nodes_estimate);
			node.fail_estimate <- extrapolate (fun c -> c.fail_estimate);
			node.solution_estimate <- extrapolate (fun c -> c.solution_estimate)
		end


(** Snapshot of the estimates held at the root of a sampling tree. *)
type estimates = {
	nodes : float;            (* Estimated number of nodes in the search space *)
	fails : float;            (* Estimated number of failures *)
	solutions : float;        (* Estimated number of solutions *)
	materialized_nodes : int; (* Number of sampling-tree nodes actually created, per [count_materialized_nodes] *)
}

(** Counts the nodes that exist in the sampling tree rooted at [node]: the
    node itself plus, for a fork, every node reachable through its
    materialized child slots. *)
let rec count_materialized_nodes (node : 'a node) : int =
	match node.node_view with
	| Fail | Result _ -> 1
	| Fork _ ->
		let count_slot = function
			| None -> 0
			| Some child -> count_materialized_nodes child
		in
		(* The initial 1 accounts for the fork itself. *)
		Array.fold_left (fun total slot -> total + count_slot slot) 1 node.children

(** [estimate ?selector n_trials space] builds a fresh sampling tree for
    [space], performs [n_trials] walks from its root using [selector]
    (default: [undersampled_selector]) and returns the resulting root
    estimates together with the number of materialized nodes. A non-positive
    [n_trials] performs no walks. *)
let estimate ?(selector=undersampled_selector) n_trials (space : 'a Searchspace.t) : estimates =
	let root = create_node space in
	let remaining = ref n_trials in
	while !remaining > 0 do
		walk selector root;
		decr remaining
	done;
	(* Read the mutable estimates only after all walks have finished. *)
	let { nodes_estimate; fail_estimate; solution_estimate; _ } = root in
	let materialized_nodes = count_materialized_nodes root in
	{
		nodes = nodes_estimate;
		fails = fail_estimate;
		solutions = solution_estimate;
		materialized_nodes;
	}

(* Compares the exact counts of [sums_div7] with the estimates obtained after
   1000 walks using the default selector. Estimates are rounded to the
   nearest integer before printing. The expected output shows every one of
   the 49 nodes materialized and all estimates equal to the true counts. *)
let%expect_test "estimate number of nodes" =
	let true_values = calculate_true_values sums_div7 in
	Printf.printf "True values\n";
	Printf.printf "  number of nodes: %d\n" true_values.nodes;
	Printf.printf "  number of fails: %d\n" true_values.fails;
	Printf.printf "  number of solutions: %d\n" true_values.solutions;
	Printf.printf "\n";
	let estimates = estimate 1000 sums_div7 in
	Printf.printf "Estimated\n";
	Printf.printf "  materialized nodes: %d\n" estimates.materialized_nodes;
	Printf.printf "  number of nodes: %d\n" (int_of_float (estimates.nodes +. 0.5));
	Printf.printf "  number of fails: %d\n" (int_of_float (estimates.fails +. 0.5));
	Printf.printf "  number of solutions: %d\n" (int_of_float (estimates.solutions +. 0.5));
	[%expect{|
   True values
     number of nodes: 49
     number of fails: 22
     number of solutions: 3

   Estimated
     materialized nodes: 49
     number of nodes: 49
     number of fails: 22
     number of solutions: 3
   |}]

(** [balanced_range start stop] is a search space whose results are the
    integers from [start] to [stop] inclusive, built by repeatedly splitting
    the range at its midpoint. It is [empty] when [start > stop].

    The two-element case is handled explicitly rather than through the
    midpoint split. *)
let rec balanced_range start stop =
	if stop < start then empty
	else if stop = start then return start
	else if stop = start + 1 then return start ++ return stop
	else
		let midpoint = (start + stop) / 2 in
		let lower_half = balanced_range start midpoint in
		let upper_half = balanced_range (midpoint + 1) stop in
		lower_half ++ upper_half

(* Runs the estimator on the space of sums n1 + n2 divisible by 7, with both
   operands in 1..100 and [int_range] locally replaced by [balanced_range].
   Each of the five runs builds a fresh estimator and uses 1000, 2000, ...,
   5000 walks. The expected output depends on the exact sequence of random
   draws made by the selector.
   NOTE(review): the space is bound to the name [right_heavy_space] although
   it is built from balanced ranges; the name looks copied from the test
   below. *)
let%expect_test "undersampling larger balanced searchspace" =
	let int_range = balanced_range in
	let right_heavy_space = (
		let* n1 = int_range 1 100 in
		let* n2 = int_range 1 100 in
		let sum = return (n1 + n2) in 
		sum |?> (fun x -> x mod 7 = 0)
	) in
		let true_values = calculate_true_values right_heavy_space in
		Printf.printf "True values\n";
		Printf.printf "  number of nodes: %d\n" true_values.nodes;
		Printf.printf "  number of fails: %d\n" true_values.fails;
		Printf.printf "  number of solutions: %d\n" true_values.solutions;
		Printf.printf "\n";
		for samplers = 1 to 5 do
			let samples = 1000 * samplers in
			Printf.printf "Sample run %d:\n" samples;
			let estimates = estimate samples right_heavy_space in
			Printf.printf "Estimated values balanced trees:\n";
			Printf.printf "  materialized nodes: %d\n" estimates.materialized_nodes;
			Printf.printf "  number of nodes: %d\n" (int_of_float (estimates.nodes +. 0.5));
			Printf.printf "  number of fails: %d\n" (int_of_float (estimates.fails +. 0.5));
			Printf.printf "  number of solutions: %d\n" (int_of_float (estimates.solutions +. 0.5));
			Printf.printf "\n";
		done;
	[%expect{|
   True values
     number of nodes: 19999
     number of fails: 8572
     number of solutions: 1428

   Sample run 1000:
   Estimated values balanced trees:
     materialized nodes: 5143
     number of nodes: 19007
     number of fails: 8228
     number of solutions: 1276

   Sample run 2000:
   Estimated values balanced trees:
     materialized nodes: 8301
     number of nodes: 19187
     number of fails: 8278
     number of solutions: 1316

   Sample run 3000:
   Estimated values balanced trees:
     materialized nodes: 10568
     number of nodes: 18661
     number of fails: 8013
     number of solutions: 1318

   Sample run 4000:
   Estimated values balanced trees:
     materialized nodes: 12268
     number of nodes: 18439
     number of fails: 7984
     number of solutions: 1236

   Sample run 5000:
   Estimated values balanced trees:
     materialized nodes: 13598
     number of nodes: 17291
     number of fails: 7448
     number of solutions: 1198
   |}]


(* Same experiment as the balanced variant, but with the [int_range] that is
   in scope from the opened modules. Each of the five runs builds a fresh
   estimator and uses 1000, 2000, ..., 5000 walks.
   The recorded output shows estimates far below the true counts (for example
   10199 estimated nodes against 20201 actual nodes after 5000 walks), only
   slightly above the number of materialized nodes. *)
let%expect_test "undersampling larger unbalanced searchspace" =
	let right_heavy_space = (
		let* n1 = int_range 1 100 in
		let* n2 = int_range 1 100 in
		let sum = return (n1 + n2) in 
		sum |?> (fun x -> x mod 7 = 0)
	) in
		let true_values = calculate_true_values right_heavy_space in
		Printf.printf "True values\n";
		Printf.printf "  number of nodes: %d\n" true_values.nodes;
		Printf.printf "  number of fails: %d\n" true_values.fails;
		Printf.printf "  number of solutions: %d\n" true_values.solutions;
		Printf.printf "\n";
		for samplers = 1 to 5 do
			let samples = 1000 * samplers in
			Printf.printf "Sample run %d:\n" samples;
			let estimates = estimate samples right_heavy_space in
			Printf.printf "Estimated values (unbalanced trees):\n";
			Printf.printf "  materialized nodes: %d\n" estimates.materialized_nodes;
			Printf.printf "  number of nodes: %d\n" (int_of_float (estimates.nodes +. 0.5));
			Printf.printf "  number of fails: %d\n" (int_of_float (estimates.fails +. 0.5));
			Printf.printf "  number of solutions: %d\n" (int_of_float (estimates.solutions +. 0.5));
			Printf.printf "\n";
		done;
	[%expect{|
   True values
     number of nodes: 20201
     number of fails: 8673
     number of solutions: 1428

   Sample run 1000:
   Estimated values (unbalanced trees):
     materialized nodes: 2099
     number of nodes: 2199
     number of fails: 946
     number of solutions: 154

   Sample run 2000:
   Estimated values (unbalanced trees):
     materialized nodes: 4100
     number of nodes: 4203
     number of fails: 1798
     number of solutions: 304

   Sample run 3000:
   Estimated values (unbalanced trees):
     materialized nodes: 6100
     number of nodes: 6203
     number of fails: 2651
     number of solutions: 451

   Sample run 4000:
   Estimated values (unbalanced trees):
     materialized nodes: 8099
     number of nodes: 8199
     number of fails: 3514
     number of solutions: 586

   Sample run 5000:
   Estimated values (unbalanced trees):
     materialized nodes: 10099
     number of nodes: 10199
     number of fails: 4378
     number of solutions: 722
   |}]

(** Incremental estimator: the root of a sampling tree together with the
    child selector used for every walk. Unlike [estimate], which samples a
    fixed number of times, a value of this type can be sampled repeatedly
    with [sample] and queried in between with [estimates]. *)
type 'a t = {
	root : 'a node;                 (* Root of the lazily materialized sampling tree *)
	selector : 'a child_selector;   (* Strategy passed to [walk] on every sample *)
}

(** [create ?selector space] builds an incremental estimator for [space] with
    no samples taken yet. [selector] defaults to [undersampled_selector]. *)
let create ?(selector=undersampled_selector) (space : 'a Searchspace.t) : 'a t =
	let root = create_node space in
	{ root; selector }

(** [sample n est] performs [n] further walks from the root of [est] using
    its selector, updating the estimator in place. A non-positive [n] does
    nothing. *)
let sample n (est : 'a t) : unit =
	let { root; selector } = est in
	let rec loop remaining =
		if remaining > 0 then begin
			walk selector root;
			loop (remaining - 1)
		end
	in
	loop n

(** [estimates est] returns the current root estimates of [est] together with
    the number of nodes materialized so far. It does not take any samples. *)
let estimates (est : 'a t) : estimates =
	let { nodes_estimate; fail_estimate; solution_estimate; _ } = est.root in
	let materialized_nodes = count_materialized_nodes est.root in
	{
		nodes = nodes_estimate;
		fails = fail_estimate;
		solutions = solution_estimate;
		materialized_nodes;
	}

(* Exercises the incremental API on the same unbalanced space: a single
   estimator is created once and then sampled 1000 more times in each of
   five iterations, printing the accumulated estimates after each batch.
   The expected output depends on the exact sequence of random draws. *)
let%expect_test "incremental estimator API on unbalanced searchspace" =
  let right_heavy_space = (
    let* n1 = int_range 1 100 in
    let* n2 = int_range 1 100 in
    let sum = return (n1 + n2) in
    sum |?> (fun x -> x mod 7 = 0)
  ) in
  let true_values = calculate_true_values right_heavy_space in
  Printf.printf "True values\n";
  Printf.printf "  number of nodes: %d\n" true_values.nodes;
  Printf.printf "  number of fails: %d\n" true_values.fails;
  Printf.printf "  number of solutions: %d\n" true_values.solutions;
  Printf.printf "\n";
  let est = create right_heavy_space in
  for samplers = 1 to 5 do
    let samples = 1000 * samplers in
    sample 1000 est;
    Printf.printf "Sample run %d:\n" samples;
    let estimates = estimates est in
    Printf.printf "Estimated values (incremental):\n";
    Printf.printf "  materialized nodes: %d\n" estimates.materialized_nodes;
    Printf.printf "  number of nodes: %d\n" (int_of_float (estimates.nodes +. 0.5));
    Printf.printf "  number of fails: %d\n" (int_of_float (estimates.fails +. 0.5));
    Printf.printf "  number of solutions: %d\n" (int_of_float (estimates.solutions +. 0.5));
    Printf.printf "\n";
  done;
  [%expect{|
    True values
      number of nodes: 20201
      number of fails: 8673
      number of solutions: 1428

    Sample run 1000:
    Estimated values (incremental):
      materialized nodes: 2099
      number of nodes: 2199
      number of fails: 946
      number of solutions: 154

    Sample run 2000:
    Estimated values (incremental):
      materialized nodes: 4101
      number of nodes: 4211
      number of fails: 1798
      number of solutions: 308

    Sample run 3000:
    Estimated values (incremental):
      materialized nodes: 6099
      number of nodes: 6199
      number of fails: 2665
      number of solutions: 435

    Sample run 4000:
    Estimated values (incremental):
      materialized nodes: 8099
      number of nodes: 8199
      number of fails: 3515
      number of solutions: 585

    Sample run 5000:
    Estimated values (incremental):
      materialized nodes: 10099
      number of nodes: 10199
      number of fails: 4372
      number of solutions: 728
    |}]