Source file optimizer.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
open Csir
open Optimizer_helpers
(** The optimizer simplifies a constraint system, producing an equivalent one
with fewer constraints in essentially three ways:
- 1. {!next_gate_selectors} [qlg], [qrg] and/or [qog], which
lead to a more compact representation of sums (linear combinations).
- 2. {!shared_wires} among different constraints.
- 3. {!compacting_blocks} which operate on independent wires.
{1 Optimization Rules}
{2:next_gate_selectors Introducing next-gate selectors}
In a system with 3-wires architecture, adding (in a linear combination) [n]
elements requires [n-1] constraints. For example, computing
[out = 1 + 2 x1 + 4 x2 + 8 x3 + 16 x4 + 32 x5] can be implemented in 4
constraints as follows:
{[
{ a = x1; b = x2; c = aux1; identity = [ql 2; qr 4; qo -1; qc 1] };
{ a = x3; b = x4; c = aux2; identity = [ql 8; qr 16; qo -1] };
{ a = x5; b = aux1; c = aux3; identity = [ql 32; qr 1; qo -1] };
{ a = aux2; b = aux3; c = out; identity = [ql 1; qr 1; qo -1] };
]}
Instead, by using next-gate selectors, a linear combination of [n] terms
can be implemented in [(n-1)/2] constraints. In this case, just two:
{[
{ a = x1; b = x2; c = out; identity = [ql 2; qr 4; qlg 8; qrg 16; qog 32; qo -1; qc 1] };
{ a = x3; b = x4; c = x5; identity = [] };
]}
{2:shared_wires Reusing shared-wires}
Remarkably, two linear combinations that involve some common wires can be
implemented more efficiently if they are handled together.
For example, [out1 = 5 x + 3 y + 9 z] and [out2 = 2 x - 3 y + 7 t] can be
implemented with 3 constraints (instead of 4) as follows:
{[
{ a = z; c = out1; identity = [ql 9; qlg 5; qrg 3; qog -1] };
{ a = x; b = y; c = out2; identity = [ql 2; qr -3; qrg 7; qog -1] };
{ b = t; identity = [] };
]}
Furthermore, observe how, conveniently, some wires in the above constraints
are unbound, which can lead to our next optimization.
{2:compacting_blocks Compacting blocks}
Consider a gate which only takes one wire as input, e.g. for computing
[out3 = 3 w + 1], implemented with constraint:
{[{ a = w; c = out3; identity = [ql 3; qo -1; qc 1] }]}
This constraint can be merged with our previous block, which only uses
wire b in its last constraint and has unbound selectors for it:
{[
{ a = z; c = out1; identity = [ql 9; qlg 5; qrg 3; qog -1] };
{ a = x; b = y; c = out2; identity = [ql 2; qr -3; qrg 7; qog -1] };
{ a = w; b = t; c = out3; identity = [ql 3; qo -1; qc 1] };
]}
{1 Algorithm}
The optimization proceeds in several steps:
1. Collect information about all linear constraints/gates (see {!is_linear}),
including the number of occurrences of their output wire among the circuit.
2. Inline any {!type:linear_computation} whose output is used only once
among the circuit, in the computation where it is used (only if the latter
is also linear).
3. Transform any non-inlined linear computation into a {!type:pseudo_block}
and join pseudo blocks that share at least a wire (preferably two wires).
4. Transform pseudo blocks to actual blocks (also considering the non-linear
computations that were left aside) and join blocks that operate on
independent wires together. *)
(** [list_map f l] returns the list of the results of [f] on the elements of
    [l], in the same order as [l]. [f] is applied from the first element to
    the last, and the traversal does not grow the call stack. *)
let list_map f l =
  let mapped_in_reverse = List.fold_left (fun acc x -> f x :: acc) [] l in
  List.rev mapped_in_reverse
(** Local alias for [Csir.nb_wires_arch]. In this file it bounds the number of
    terms that fit in the head or in the tail of a pseudo constraint. *)
let nb_wires_arch = Csir.nb_wires_arch
(** A single raw constraint, as defined by [Csir.CS]. *)
type constr = Csir.CS.raw_constraint
(** Coefficients appearing in constraints. *)
type scalar = Csir.Scalar.t [@@deriving repr]
(** A block is an array of constraints that are kept together, in order. *)
type block = constr array
(** A term [(q, w)] pairs a coefficient [q] with a wire index [w]. A negative
    wire index is used for the constant term (see [pseudo_block_of_terms]). *)
type term = scalar * int [@@deriving repr]
(** Information needed by [trace_updater] to adapt a trace to the optimized
    constraint system.
    [free_wires] is a list of indices of wires that have been removed from the
    system and have not been used again for auxiliary computations.
    When a free wire [w] is used to store an auxiliary linear computation on
    some [terms], it is removed from [free_wires] and the pair [(w, terms)]
    is added to the list of [assignments]. New assignments are added at the
    front of the list. *)
type trace_info = {free_wires : int list; assignments : (int * term list) list}
[@@deriving repr]
(** A pseudo constraint represents an equation between an arbitrary number of
    terms: the sum of all the terms in its head, body and tail, plus its
    constant term, must be equal to 0.
    The lengths of [pc_head] and [pc_tail] should never exceed
    [nb_wires_arch]; [pc_body] has no such bound. *)
type pseudo_constr = {
pc_head : term list;
pc_body : term list;
pc_tail : term list;
pc_const : Scalar.t;
}
(** A pseudo block is a list of pseudo constraints that have been grouped
    in order to apply the {!shared_wires} heuristic.
    The constraints in a pseudo block satisfy (by construction) the following
    invariant: the [pc_head] of a constraint starts with the [pc_tail] of the
    previous one. The invariant is about the wires of the terms; the
    coefficients may be different.
    Note that the [pc_head] may contain extra terms that do not appear in the
    previous [pc_tail].
    This invariant is important for [block_of_pseudo_block] to be correct. *)
type pseudo_block = pseudo_constr list
(** [join_list ~join scope x xs] looks for a partner for [x] among the first
    [scope] elements of [xs].
    For every candidate [c], [join x c] is tried first and [join c x] second.
    On the first success the candidate is replaced by the union and
    [Some xs'] is returned, where [xs'] holds the union followed by the
    untouched candidates that came after it, preceded by the rejected
    candidates in reverse order (so the order of [xs] is not preserved).
    Returns [None] if [xs] is exhausted or [scope] candidates were rejected. *)
let join_list :
    join:('a -> 'a -> 'a option) -> int -> 'a -> 'a list -> 'a list option =
 fun ~join scope x xs ->
  (* [x] is tried on the left of the candidate first, then on its right. *)
  let try_join candidate =
    match join x candidate with
    | Some _ as joined -> joined
    | None -> join candidate x
  in
  let rec search tried rejected = function
    | candidate :: remaining when tried < scope -> (
        match try_join candidate with
        | Some union -> Some (List.rev_append rejected (union :: remaining))
        | None -> search (tried + 1) (candidate :: rejected) remaining)
    | _ -> None
  in
  search 0 [] xs
(** [combine_quadratic ~scope ~join xs] takes a list of elements [xs] and
    returns a (potentially) shorter list where some of the elements have been
    combined according to [join]. [join x x'] is [None] if [x] and [x'] cannot
    be combined and it is [Some union] otherwise.
    Every element is greedily tried (through [join_list]) against at most the
    next [scope] elements that follow it; when a union is produced it takes
    the place of the partner and can itself be combined again later on.
    The order of the elements is not preserved. *)
let combine_quadratic :
    scope:int -> join:('a -> 'a -> 'a option) -> 'a list -> 'a list =
 fun ~scope ~join groups ->
  let rec loop unmatched pending =
    match pending with
    | [] -> unmatched
    | candidate :: others -> (
        match join_list ~join scope candidate others with
        | Some merged -> loop unmatched merged
        | None -> loop (candidate :: unmatched) others)
  in
  loop [] groups
(** Similar to [combine_quadratic], but more efficient.
    Every element of [top_candidates] is tried (through [join_list], hence
    against at most [scope] elements) against the pool [rest]; on success the
    union replaces its partner inside the pool.
    The candidates that found no partner are finally combined among
    themselves with [combine_quadratic] and put in front of the pool.
    This function may not preserve the order in [rest]. *)
let combine_quadratic_efficient ~scope ~join top_candidates rest =
  let rec loop unplaced candidates pool =
    match candidates with
    | candidate :: remaining -> (
        match join_list ~join scope candidate pool with
        | Some pool -> loop unplaced remaining pool
        | None -> loop (candidate :: unplaced) remaining pool)
    | [] ->
        let leftovers = combine_quadratic ~scope ~join unplaced in
        List.rev_append leftovers pool
  in
  loop [] top_candidates rest
(** [add_terms terms1 terms2] concatenates two lists of terms, merging the
    terms that share a wire identifier by adding their coefficients.
    The result is sorted by increasing wire identifier and contains each wire
    at most once. Terms whose coefficients cancel out are kept (with the
    resulting coefficient), they are not filtered out. *)
let add_terms : term list -> term list -> term list =
 fun terms1 terms2 ->
  let by_decreasing_wire (_, w1) (_, w2) = Int.compare w2 w1 in
  (* The state is the list of finished terms together with the term that is
     still being accumulated. Since the input is sorted by decreasing wire,
     consing the finished terms yields an increasing result. *)
  let step (finished, (q, w)) (q', w') =
    if w = w' then (finished, (Scalar.add q q', w))
    else ((q, w) :: finished, (q', w'))
  in
  match List.sort by_decreasing_wire (List.rev_append terms1 terms2) with
  | [] -> []
  | first :: others ->
      let finished, current = List.fold_left step ([], first) others in
      current :: finished
(** [inline w ~from ~into linear] takes an array of linear constraints (each
    represented by a list of terms), solves constraint [linear.(from)] for
    wire [w] and substitutes the result for [w] in constraint [linear.(into)].
    Concretely, if [w] has coefficient [q1] in [linear.(from)] and [q2] in
    [linear.(into)], every other term of [linear.(from)] is scaled by
    [- q2 / q1] and added (with [add_terms]) to the other terms of
    [linear.(into)]. [linear.(from)] is left empty.
    This routine should only be called if [w] does not appear in any other
    constraint. In that case, constraint [linear.(from)] can then be removed
    and wire [w] is said to be "free". *)
let inline w ~from ~into linear : unit =
  let is_w (_, wire) = w = wire in
  let (q1, _), others_from = find_and_remove is_w linear.(from) in
  let (q2, _), others_into = find_and_remove is_w linear.(into) in
  let scale (q, wire) = (Scalar.(negate q * q2 / q1), wire) in
  linear.(from) <- [] ;
  linear.(into) <- add_terms (List.map scale others_from) others_into
(** Maps with integer keys. In the occurrence map built by [inline_linear],
    the keys are wire indices and the values are lists of constraint indices,
    such that the wire appears in each of the constraints. *)
module IMap = Map.Make (Int)
(** Sets of integers, used in this file for sets of wire indices. *)
module ISet = Set.Make (Int)
(** Auxiliary function used in [inline_linear].
    [increase_occurrences map i terms] returns [map] where, for every
    [(_, w)] in [terms], [i] has been added at the front of the list of
    occurrences bound to wire [w] (a wire without binding starts from the
    empty list). *)
let increase_occurrences map i terms =
  let prepend_i = function
    | None -> Some [i]
    | Some occurrences -> Some (i :: occurrences)
  in
  let record occurrences_map (_, w) = IMap.update w prepend_i occurrences_map in
  List.fold_left record map terms
(** [pseudo_block_of_terms terms] creates a pseudo block made of a single
    pseudo constraint, with empty head and tail, whose body contains all the
    [terms] on non-negative wires.
    The terms on negative wires encode the constant: there must be at most
    one of them and it must be on wire [-1]; its coefficient becomes
    [pc_const] (zero if there is none). Any other shape is considered
    impossible and fails an assertion. *)
let pseudo_block_of_terms terms =
  let on_constant_wire (_, w) = w < 0 in
  let constants, body = List.partition on_constant_wire terms in
  let pc_const =
    match constants with
    | [] -> Scalar.zero
    | [(q, -1)] -> q
    | _ -> assert false
  in
  [{pc_head = []; pc_body = body; pc_tail = []; pc_const}]
(** [inline_linear ~nb_inputs ~range_checked gates] inlines linear gates into
    each other when possible.
    It returns a triple [(pseudo_blocks, non_linear, free_wires)] where:
    - [pseudo_blocks] are the remaining linear gates (possibly after
      inlining), one pseudo block per gate. They are returned as pseudo
      blocks because they may need to be split again to fit the architecture.
    - [non_linear] are the gates that were left unchanged: gates with more
      than one constraint or whose constraint is not linear.
    - [free_wires] is the sorted list of wires that were freed.
    The wires considered for inlining must:
    - not be inputs (their index is at least [nb_inputs])
    - not appear in any of the [non_linear] gates
    - not be range-checked
    Among those, a wire with exactly two recorded occurrences in linear gates
    is inlined (see [inline]); a wire with exactly one recorded occurrence is
    freed and the constraint where it occurs is removed. Wires with more
    occurrences are left alone. *)
let inline_linear ~nb_inputs ~range_checked gates =
(* Gates made of a single linear constraint are processed here; all the
   other gates are returned untouched. *)
let linear, non_linear =
List.partition_map
(fun gate ->
if Array.length gate = 1 && CS.is_linear_raw_constr gate.(0) then
Left gate.(0)
else Right gate)
gates
in
(* From now on, a linear constraint is an (indexed) list of terms. *)
let linear = Array.(map CS.linear_terms @@ of_list linear) in
let non_linear_wires =
List.concat_map CS.gate_wires non_linear |> ISet.of_list
in
(* For every candidate wire, the indices of the linear constraints where it
   occurs. *)
let wire_occurrences =
Array.fold_left
(fun (map, i) gate -> (increase_occurrences map i gate, succ i))
(IMap.empty, 0)
linear
|> fst
|> IMap.filter (fun w _ ->
w >= nb_inputs
&& (not (ISet.mem w non_linear_wires))
&& not (Range_checks.mem w range_checked))
in
(* [pointers] maps the index of a constraint that has been inlined (and thus
   emptied) to the index of the constraint that received its terms.
   [updated_pointer] follows these links up to the constraint that currently
   holds the terms, shortening the links it went through on the way. *)
let pointers = ref IMap.empty in
let rec updated_pointer occ =
match IMap.find_opt occ !pointers with
| Some p ->
let p' = updated_pointer p in
pointers := IMap.add occ p' !pointers ;
p'
| None -> occ
in
let free_wires, removable_constrs =
IMap.fold
(fun w occs (free_wires, removable_constrs) ->
match occs with
| [occ] ->
(* The wire occurs in a single linear constraint: the wire is freed and
   the constraint is scheduled for removal. *)
let occ = updated_pointer occ in
(w :: free_wires, occ :: removable_constrs)
| [occ1; occ2] ->
(* NOTE(review): if both occurrences resolve to the same constraint
   (from = into), a self-pointer is recorded below and a later call to
   [updated_pointer] on it would not terminate -- confirm this cannot
   happen. *)
let from = updated_pointer occ1 in
let into = updated_pointer occ2 in
inline w ~from ~into linear ;
pointers := IMap.add from into !pointers ;
(w :: free_wires, removable_constrs)
| _ -> (free_wires, removable_constrs))
wire_occurrences
([], [])
in
(* Empty the removable constraints, then drop every empty constraint (the
   removed ones and the ones that were inlined into another constraint). *)
List.iter (fun c -> linear.(c) <- []) removable_constrs ;
let linear = List.filter (fun ts -> ts <> []) (Array.to_list linear) in
let free_wires = List.sort Int.compare free_wires in
let pseudo_blocks = List.map pseudo_block_of_terms linear in
(pseudo_blocks, non_linear, free_wires)
(** [place_over_pseudo_block ~perfectly b1 b2] tries to join pseudo block [b1]
    on top of [b2], i.e., link the last pseudo constraint of [b1] with the
    first pseudo constraint of [b2]. It returns [Some block] with the
    combined pseudo block if the combination was successful and [None]
    otherwise.
    The combination is only attempted if the last pseudo constraint of [b1]
    has an empty tail and the first pseudo constraint of [b2] has an empty
    head. The link is made of wires that appear in the bodies of both
    constraints: their terms are moved from the body to the tail of the
    former and from the body to the head of the latter (sorted by wire).
    [perfectly] is a Boolean argument that specifies whether only "perfect"
    matches should be accepted. In this context, "perfect" means that the two
    constraints to be combined share exactly [nb_wires_arch] wires.
    A non-perfect match is accepted when at most one wire is missing, that
    is, when [nb_wires_arch - 1] wires are shared. *)
let place_over_pseudo_block ~perfectly b1 b2 =
(* NOTE(review): [split_last] and [split_first] come from the helpers module;
   presumably they isolate the last/first element of a non-empty list. *)
let b1_rest, b1_k = split_last b1 in
let b2_0, b2_rest = split_first b2 in
if b1_k.pc_tail <> [] || b2_0.pc_head <> [] then None
else
let b1_body = List.map snd b1_k.pc_body in
let b2_body = List.map snd b2_0.pc_body in
let common_body = list_intersection ~equal:( = ) b1_body b2_body in
let b2_free_room = (2 * nb_wires_arch) - List.length b2_0.pc_tail in
(* [n] is the maximum number of common wires that will be used for the link:
   [nb_wires_arch] only when the body of [b2_0] has exactly [nb_wires_arch]
   terms (and fits in its free room), [nb_wires_arch - 1] otherwise. *)
let n =
if
List.compare_length_with b2_body b2_free_room <= 0
&& List.compare_length_with b2_body nb_wires_arch = 0
then nb_wires_arch
else nb_wires_arch - 1
in
(* NOTE(review): [rev_first_n ~n] presumably keeps at most [n] of the common
   wires -- confirm against the helpers module. *)
let common = rev_first_n ~n common_body |> ISet.of_list in
let tight = ISet.cardinal common = nb_wires_arch in
if perfectly && not tight then None
else if nb_wires_arch - ISet.cardinal common > 1 then None
else
(* Move the terms on common wires out of both bodies. They are sorted by
   wire so that the tail and the head list the wires in the same order. *)
let split_body = List.partition (fun (_, w) -> ISet.mem w common) in
let qcommon1, body1 = split_body b1_k.pc_body in
let qcommon2, body2 = split_body b2_0.pc_body in
let compare_terms (_, w1) (_, w2) = Int.compare w1 w2 in
let qcommon1 = List.sort compare_terms qcommon1 in
let qcommon2 = List.sort compare_terms qcommon2 in
let pc1 = {b1_k with pc_tail = qcommon1; pc_body = body1} in
let pc2 = {b2_0 with pc_head = qcommon2; pc_body = body2} in
Some (b1_rest @ (pc1 :: pc2 :: b2_rest))
(** [combine_pseudo_blocks ~scope ~perfectly pseudo_blocks] combines the given
    pseudo blocks with [combine_quadratic]. Two pseudo blocks are joined with
    [place_over_pseudo_block], trying to place the first one over the second
    one and, if that fails, the second one over the first one.
    [scope] and [perfectly] are forwarded to [combine_quadratic] and
    [place_over_pseudo_block] respectively. *)
let combine_pseudo_blocks ~scope ~perfectly pseudo_blocks =
  let join upper lower =
    match place_over_pseudo_block ~perfectly upper lower with
    | Some _ as stacked -> stacked
    | None -> place_over_pseudo_block ~perfectly lower upper
  in
  combine_quadratic ~scope ~join pseudo_blocks
(** [linear_combination ~this ~next qc] builds (with [CS.mk_linear_constr]) a
    linear constraint with constant selector [qc] from two lists of terms:
    - [this] gives the wires of the constraint and the coefficients of its
      selectors on the current constraint;
    - [next] gives the coefficients of the next-gate selectors (the wires of
      [next] are not stored here, see [dummy_linear_combination]).
    The wires of [this] are padded with [0] up to [nb_wires_arch] and then
    reversed. Accordingly, the coefficients are paired with the selector
    lists taken in reverse order, cut (with [list_sub]) to the number of
    coefficients. *)
let linear_combination ~this ~next qc =
  (* Pairs the coefficients with the reversed selector names, cut to the
     number of coefficients. *)
  let attach selectors coeffs =
    let used = list_sub (List.length coeffs) (List.rev selectors) in
    List.combine used coeffs
  in
  let padding = List.init (nb_wires_arch - List.length this) (fun _ -> 0) in
  let wires = List.rev (List.map snd this @ padding) in
  let this_sels = attach CS.this_constr_linear_selectors (List.map fst this) in
  let next_sels = attach CS.next_constr_linear_selectors (List.map fst next) in
  CS.mk_linear_constr (wires, ("qc", qc) :: (this_sels @ next_sels))
(** [dummy_linear_combination terms] builds (with [CS.mk_linear_constr]) a
    constraint without selectors whose wires are the wires of [terms], padded
    with [0] up to [nb_wires_arch] and then reversed, as done in
    [linear_combination]. The coefficients of [terms] are ignored. *)
let dummy_linear_combination terms =
  let padding = List.init (nb_wires_arch - List.length terms) (fun _ -> 0) in
  let wires = List.rev (List.map snd terms @ padding) in
  CS.mk_linear_constr (wires, [])
(** [block_of_pseudo_block trace_info pseudo_block] transforms a pseudo block,
    i.e., a list of pseudo constraints, into a block.
    Thanks to the invariant that the [pc_head] of a pseudo constraint starts
    with the [pc_tail] of the previous one, it is enough to add a constraint
    for each [pc_head] (and encode each [pc_tail] with next-gate selectors).
    When the last pseudo constraint ends up with a non-empty tail, an extra
    selector-free constraint carrying the wires of that tail is appended.
    Note that a [pc_body] may not fit in the free room of its constraint. In
    that case the terms that do not fit are replaced by one auxiliary
    variable: a free wire [v] is taken from [trace_info.free_wires], the term
    [(1, v)] is added to the head, the assignment of the pending terms to [v]
    is recorded in [trace_info.assignments] and a residue pseudo block
    stating that [v] equals the sum of the pending terms is produced.
    Returns the updated [trace_info], the block and the list of residue
    pseudo blocks, which still need to be converted. *)
let block_of_pseudo_block trace_info pseudo_block =
let pseudo_block = Array.of_list pseudo_block in
Array.fold_left
(fun (trace_info, block, residue_pseudo_blocks, i) b ->
let is_last = i = Array.length pseudo_block - 1 in
(* Number of body terms that can still be placed: the unused slots of the
   head and, for the last pseudo constraint only, those of the tail. *)
let free_room = nb_wires_arch - List.length b.pc_head in
let free_room =
free_room + if is_last then nb_wires_arch - List.length b.pc_tail else 0
in
let aux_var_needed = List.compare_length_with b.pc_body free_room > 0 in
(* One slot of the head is reserved for the auxiliary variable, if needed. *)
let n = if aux_var_needed then nb_wires_arch - 1 else nb_wires_arch in
(* NOTE(review): [complete ~n l terms] presumably moves terms from [terms]
   into [l] until [l] has [n] elements, returning the completed list and
   the terms left over -- confirm against the helpers module. *)
let head, pending = complete ~n b.pc_head b.pc_body in
let tail, pending =
if is_last then complete ~n:nb_wires_arch b.pc_tail pending
else (b.pc_tail, pending)
in
let trace_info, head, residues =
if not aux_var_needed then (
assert (pending = []) ;
(trace_info, head, []))
else
(* NOTE(review): this relies on [trace_info.free_wires] being non-empty
   whenever an auxiliary variable is needed -- confirm what [split_first]
   does on an empty list. *)
let v, free_wires = split_first trace_info.free_wires in
let assignments = (v, pending) :: trace_info.assignments in
assert (List.compare_length_with head nb_wires_arch < 0) ;
let head = head @ [(Scalar.one, v)] in
let trace_info = {free_wires; assignments} in
(* The residue states that the pending terms minus [v] add up to zero. *)
let residues =
[pseudo_block_of_terms @@ ((Scalar.(negate one), v) :: pending)]
in
(trace_info, head, residues)
in
let constr1 = linear_combination ~this:head ~next:tail b.pc_const in
let constrs =
if is_last && tail <> [] then [constr1; dummy_linear_combination tail]
else [constr1]
in
(trace_info, block @ constrs, residues @ residue_pseudo_blocks, succ i))
(trace_info, [], [], 0)
pseudo_block
|> fun (trace_info, block, residue_pseudo_blocks, _) ->
(trace_info, Array.of_list block, residue_pseudo_blocks)
(** [blocks_of_pseudo_blocks ~scope free_wires pseudo_blocks] takes a list of
    pseudo blocks, combines them if possible (first looking for perfect
    matches only, then for any match) and converts them into blocks with
    [block_of_pseudo_block].
    This conversion may lead to residue pseudo blocks (if auxiliary variables
    were involved to free a pc_body), in which case the same process is
    applied again on the residue pseudo blocks, until there are none left.
    Observe that it is relatively rare that a residue pseudo block is
    produced after the conversion. Consequently the second iteration will be
    among many fewer pseudo blocks and it quickly reaches termination.
    Returns the final {!type:trace_info}, which starts with the given
    [free_wires] and no assignments, together with the list of blocks. *)
let blocks_of_pseudo_blocks ~scope free_wires pseudo_blocks =
  (* Converts one pseudo block, threading the trace information and
     collecting the produced block and the new residues. *)
  let convert (trace_info, blocks, residues) pseudo_block =
    let trace_info, block, new_residues =
      block_of_pseudo_block trace_info pseudo_block
    in
    (trace_info, block :: blocks, List.rev_append new_residues residues)
  in
  let rec iteration trace_info blocks = function
    | [] -> (trace_info, blocks)
    | pending ->
        let perfect_first = combine_pseudo_blocks ~scope ~perfectly:true pending in
        let joined = combine_pseudo_blocks ~scope ~perfectly:false perfect_first in
        let trace_info, blocks, residues =
          List.fold_left convert (trace_info, blocks, []) joined
        in
        iteration trace_info blocks residues
  in
  iteration {free_wires; assignments = []} [] pseudo_blocks
(** [place_over_block ~perfectly b1 b2] tries to join block [b1] on top
    of [b2]. It works only if the last constraint of [b1] has no selectors.
    [perfectly] returns only matches that have no empty selectors (no holes).
    It is relevant for a heuristic where we first want to find perfect matches
    and then any match.
    The joining is done by
    - finding a permutation that makes the wires of b1.last and b2.first match
      (being relaxed about unset wires)
    - checking if the permutation is perfect and failing if required
    - renaming the selectors of b1.last
    - renaming the next selectors of b1.second_last
    Example:
    {v
    [ (1 1 1) [ql:a qlg:b]
      (3 2 -1) [] ]
    [ (2 3 4) [ql:a]
      (1 1 1) [ql:a] ]
    becomes
    [ (1 1 1) [ql:a qrg:b]
      (2 3 4) [ql:a]
      (1 1 1) [ql:a] ]
    v}
    Returns [None] when the blocks cannot be joined. On success, the last two
    constraints of [b1] are overwritten in place before the result is built,
    so [b1] must not be used afterwards. *)
let place_over_block : perfectly:bool -> block -> block -> block option =
fun ~perfectly b1 b2 ->
let b1_k = b1.(Array.length b1 - 1) in
if b1_k.sels <> [] then None
else
let wires1 = CS.wires_of_constr_i b1 (Array.length b1 - 1) in
let wires2 = CS.wires_of_constr_i b2 0 in
(* NOTE(review): [adapt] comes from the helpers module; presumably it looks
   for a permutation of [wires1] that is compatible with [wires2]. *)
match adapt wires1 wires2 with
| None -> None
| Some perm ->
let wires1 = permute_list perm wires1 in
(* Wire by wire, keep the wire of [b1] when it is set (non-negative) and the
   wire of [b2] otherwise. *)
let combined =
List.map2 (fun x y -> if x >= 0 then x else y) wires1 wires2
in
let is_perfect = List.for_all (fun x -> x >= 0) combined in
(* NOTE(review): this access assumes [b1] has at least two constraints; a
   single-constraint [b1] without selectors would raise here -- confirm
   such blocks cannot reach this point. *)
let b1_k_1 = b1.(Array.length b1 - 2) in
if perfectly && not is_perfect then None
else if
not (is_identity_perm perm || CS.is_arithmetic_raw_constr b1_k_1)
then None
else
(* [rename] maps a next-gate linear selector name to the one it becomes
   under [perm]; any other selector name is left unchanged. *)
let rename n =
let selectors_assoc =
List.combine
(permute_list perm CS.next_constr_linear_selectors)
CS.next_constr_linear_selectors
in
assert (List.compare_length_with selectors_assoc nb_wires_arch = 0) ;
match List.find_opt (fun (s, _) -> s = n) selectors_assoc with
| Some (_, s') -> s'
| None -> n
in
let sels_prev = List.map (fun (n, q) -> (rename n, q)) b1_k_1.sels in
(* Wires that are still unset (negative) are set to 0. *)
let wires =
List.map (max 0) combined |> Array.of_list
in
let b2_0 = b2.(0) in
let b2_rest = Array.sub b2 1 (Array.length b2 - 1) in
assert (b1_k.precomputed_advice = []) ;
(* The last constraint of [b1] takes the selectors, advice and label of the
   first constraint of [b2], on the combined wires. *)
b1.(Array.length b1 - 1) <-
{
wires;
sels = b2_0.sels;
precomputed_advice = b2_0.precomputed_advice;
label = "hyb" :: b2_0.label;
} ;
b1.(Array.length b1 - 2) <- {b1_k_1 with sels = sels_prev} ;
Some (Array.append b1 b2_rest)
(** [combine_blocks ~scope ~perfectly blocks] tries to stack blocks on top of
    each other with [place_over_block].
    The blocks whose last constraint has no selectors are the only ones that
    can be placed over another block. They are the top candidates given to
    [combine_quadratic_efficient], which tries each of them against at most
    [scope] of the other blocks.
    [perfectly] is forwarded to [place_over_block]: when [true], only matches
    without unset wires are accepted.
    Every block must contain at least one constraint. The order of the blocks
    is not preserved. *)
let combine_blocks ~scope ~perfectly blocks =
  let blocks_with_empty_selectors, rest =
    let open CS in
    List.partition (fun b -> b.(Array.length b - 1).sels = []) blocks
  in
  (* [~scope] is declared and forwarded explicitly, instead of being left as
     a pending labelled parameter of a partial application. *)
  combine_quadratic_efficient
    ~scope
    ~join:(place_over_block ~perfectly)
    blocks_with_empty_selectors
    rest
(** Takes {!type:trace_info} and returns a function that updates a given
    trace to make it compatible with the optimized constraints.
    The assignments are replayed from the oldest to the most recent (they are
    stored most recent first): for each assignment [(i, terms)], position [i]
    of the trace is overwritten with the sum of [coeff * trace.(w)] over the
    terms [(coeff, w)]. The trace is updated in place and returned. *)
let trace_updater : trace_info -> scalar array -> scalar array =
 fun trace_info trace ->
  (* Value of a linear combination of wires on the current trace. *)
  let evaluate terms =
    List.fold_left
      (fun total (coeff, i) -> Scalar.add total (Scalar.mul coeff trace.(i)))
      Scalar.zero
      terms
  in
  let oldest_first = List.rev trace_info.assignments in
  List.iter (fun (i, terms) -> trace.(i) <- evaluate terms) oldest_first ;
  trace
(** [remove_boolean_gates gates] removes the gates made of a single
    constraint for which [CS.boolean_raw_constr] returns [Some b].
    Returns the remaining gates, in their original order, together with the
    list of the values [b] that were collected, the one of the last removed
    gate first. *)
let remove_boolean_gates gates =
  let classify gate =
    if Array.length gate <> 1 then Either.Left gate
    else
      match CS.boolean_raw_constr gate.(0) with
      | Some b -> Either.Right b
      | None -> Either.Left gate
  in
  let kept, boolean_vars = List.partition_map classify gates in
  (kept, List.rev boolean_vars)
(** [add_boolean_checks ~boolean_vars gates] adds a boolean check (selector
    [qbool] set to one) for every wire in [boolean_vars], in three passes:
    - first, every constraint whose first wire is a pending boolean variable
      gets the selector;
    - then, every constraint whose first wire is reported as negative by
      [CS.wires_of_constr_i] takes one of the pending boolean variables as
      its first wire (the wires array of the constraint is updated in place)
      and gets the selector;
    - finally, the variables that are still pending get a dedicated
      constraint built with [CS.mk_bool_constr]; all of them are put in a
      single new gate.
    The resulting list contains the gates in reverse order, followed by the
    new gate. *)
let add_boolean_checks ~boolean_vars gates =
  let pending = ref (ISet.of_list boolean_vars) in
  let mark_boolean constr =
    CS.{constr with sels = ("qbool", Scalar.one) :: constr.sels}
  in
  let shared_a constr =
    let a_wire = CS.(constr.wires.(0)) in
    match ISet.mem a_wire !pending with
    | false -> constr
    | true ->
        pending := ISet.remove a_wire !pending ;
        mark_boolean constr
  in
  let unused_a ~gate i constr =
    let wires = CS.wires_of_constr_i gate i in
    if List.hd wires >= 0 || ISet.is_empty !pending then constr
    else begin
      let a = ISet.choose !pending in
      (* In-place update: the returned record shares this wires array. *)
      CS.(constr.wires).(0) <- a ;
      pending := ISet.remove a !pending ;
      mark_boolean constr
    end
  in
  let gates = list_map (Array.map shared_a) gates in
  let gates = list_map (fun gate -> Array.mapi (unused_a ~gate) gate) gates in
  let new_bool_gates =
    ISet.elements !pending |> list_map CS.mk_bool_constr |> Array.of_list
  in
  List.rev (new_bool_gates :: gates)
(** [inline_renamings ~nb_inputs ~range_checked gates] removes the wires that
    are mere renamings of other wires.
    A renaming is a gate made of a single linear constraint with exactly two
    terms, on non-negative wires, whose coefficients add up to zero. Wires
    related by renamings are grouped into classes (computed by iterating over
    the renaming pairs until nothing changes) and every wire of a class is
    renamed to the smallest wire of its class, except for inputs (wires below
    [nb_inputs]) and range-checked wires, which keep their name.
    Returns the gates after renaming (with [CS.rename_wires]) together with
    the list of wires that have been renamed to a different wire, in
    increasing order. The renaming gates themselves are not removed here (see
    [remove_trivial]). *)
let inline_renamings ~nb_inputs ~range_checked gates =
  let as_renaming gate =
    if not (Array.length gate = 1 && CS.is_linear_raw_constr gate.(0)) then None
    else
      match CS.linear_terms gate.(0) with
      | [(q1, w1); (q2, w2)]
        when w1 >= 0 && w2 >= 0 && Scalar.(is_zero @@ add q1 q2) ->
          Some (w1, w2)
      | _ -> None
  in
  let renaming_pairs = List.filter_map as_renaming gates in
  let class_of wire classes =
    match IMap.find_opt wire classes with
    | Some members -> members
    | None -> ISet.empty
  in
  (* Merges the classes of both wires of a pair, remembering whether any of
     the two classes grew. *)
  let merge_pair (classes, changed) (i, j) =
    let i_class = class_of i classes in
    let j_class = class_of j classes in
    let merged = ISet.union i_class j_class |> ISet.add i |> ISet.add j in
    let grew = not (ISet.equal i_class merged && ISet.equal j_class merged) in
    let classes = IMap.add i merged classes |> IMap.add j merged in
    (classes, changed || grew)
  in
  (* Only the classes of the two wires of a pair are updated by a merge, so
     the pairs are processed again until a fixpoint is reached. *)
  let rec saturate classes =
    let classes, changed =
      List.fold_left merge_pair (classes, false) renaming_pairs
    in
    if changed then saturate classes else classes
  in
  let representative wire members =
    let pinned = wire < nb_inputs || Range_checks.mem wire range_checked in
    if pinned then wire else ISet.min_elt members
  in
  let renaming = IMap.mapi representative (saturate IMap.empty) in
  let free_wires =
    IMap.fold
      (fun wire target acc -> if wire <> target then wire :: acc else acc)
      renaming
      []
    |> List.rev
  in
  let rename i =
    match IMap.find_opt i renaming with Some target -> target | None -> i
  in
  (list_map (CS.rename_wires ~rename) gates, free_wires)
(** [remove_trivial gates] removes the trivial gates, i.e., the gates made of
    a single linear constraint with exactly two terms, on the same
    non-negative wire, whose coefficients add up to zero. All the other gates
    are kept, in the same order. *)
let remove_trivial gates =
  let is_trivial gate =
    Array.length gate = 1
    && CS.is_linear_raw_constr gate.(0)
    && (match CS.linear_terms gate.(0) with
       | [(q1, w1); (q2, w2)]
         when w1 >= 0 && w2 >= 0 && Scalar.(is_zero @@ add q1 q2) ->
           w1 = w2
       | _ -> false)
  in
  List.filter (fun gate -> not (is_trivial gate)) gates
(** Takes a list of gates and returns an equivalent constraint system with
    potentially fewer constraints.
    As a second output, it returns the necessary information to build a
    function that updates the trace to make it compatible with the new
    constraints (see [trace_updater]).
    The steps are, in order: inlining of renamings, removal of trivial gates,
    inlining of linear gates, removal of boolean gates, conversion of the
    linear pseudo blocks into blocks, combination of blocks (first perfect
    matches, then any match, then twice more after a shuffle with a fixed
    seed) and, finally, reinsertion of the boolean checks.
    The [scope] of the combination heuristics is reduced from 100 to 10 when
    there are more than 10_000 non-linear gates. *)
let optimize :
    nb_inputs:int ->
    range_checks:Range_checks.t ->
    CS.gate list ->
    CS.gate list * trace_info =
 fun ~nb_inputs ~range_checks:range_checked gates ->
  let gates, renamed_wires = inline_renamings ~nb_inputs ~range_checked gates in
  let linear_pseudo_blocks, non_linear, inlined_wires =
    gates |> remove_trivial |> inline_linear ~nb_inputs ~range_checked
  in
  let free_wires = List.sort Int.compare (renamed_wires @ inlined_wires) in
  let non_linear, boolean_vars = remove_boolean_gates non_linear in
  let many_non_linear = List.compare_length_with non_linear 10_000 > 0 in
  let scope = if many_non_linear then 10 else 100 in
  let trace_info, linear_blocks =
    blocks_of_pseudo_blocks ~scope free_wires linear_pseudo_blocks
  in
  let combine ~perfectly blocks = combine_blocks ~scope ~perfectly blocks in
  (* Shuffling gives the greedy combination another chance to find matches. *)
  let reshuffle_and_combine seed blocks =
    blocks |> shuffle_list ~seed |> combine ~perfectly:false
  in
  let cs =
    List.rev_append non_linear linear_blocks
    |> combine ~perfectly:true
    |> combine ~perfectly:false
    |> reshuffle_and_combine 1031
    |> reshuffle_and_combine 42
  in
  (add_boolean_checks ~boolean_vars cs, trace_info)