package saga

  1. Overview
  2. Docs

Source file bpe.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
(** Byte Pair Encoding implementation *)
module IntPair = struct
  type t = int * int

  (* Lexicographic order on (left, right). A monomorphic comparison is used
     instead of the generic structural [compare]: these keys are looked up
     for every candidate merge, and [Int.compare] avoids the runtime
     polymorphic traversal while giving exactly the same order. *)
  let compare ((a1, b1) : t) ((a2, b2) : t) =
    match Int.compare a1 a2 with 0 -> Int.compare b1 b2 | c -> c
end

(** Maps keyed by (left id, right id) pairs, e.g. the merge table. *)
module IntPairMap = Map.Make (IntPair)

module IntPairSet = Set.Make (IntPair)

(** Sets of integers (ordered with [Int.compare]). *)
module IntSet = Set.Make (Int)

module StringMap = Map.Make (String)

(** Token string -> token id. *)
type vocab = (string, int) Hashtbl.t

(** Token id -> token string; [create] builds it by inverting a [vocab]. *)
type vocab_r = (int, string) Hashtbl.t

(** Merge rules as (left, right) token strings; a rule's position in the
    list is its rank (earlier = applied first). *)
type merges = (string * string) list

(** Compiled merge rules: (left id, right id) -> (rank, merged token id), as
    built by [convert_merges_to_merge_map]. *)
type merge_map = (int * int) IntPairMap.t

(** One symbol of a word being merged. While merging, symbols form a doubly
    linked list through indices into the owning word's [symbols] array. *)
type symbol = {
  mutable c : int;  (** Token id. *)
  mutable prev : int;  (** Index of the previous live symbol, or [-1]. *)
  mutable next : int;  (** Index of the next live symbol, or [-1]. *)
  mutable len : int;
      (** Number of input bytes this symbol covers; [0] marks a symbol that
          was merged into its left neighbour. *)
}

(** A word under construction: [symbols] is pre-sized and only its first
    [size] slots are in use. *)
type word = { mutable symbols : symbol array; mutable size : int }

(** A tokenizer output token; [offsets] is the [(start, end)] byte span of
    the token within the tokenized text, with [end] exclusive. *)
type token = { id : int; value : string; offsets : int * int }

(** Merged words stored in the tokenization cache. *)
type cache_entry = word

(** Parameters accepted by [create]. *)
type config = {
  vocab : vocab;  (** Token -> id; [create] stores this table as-is. *)
  merges : merges;  (** Merge rules, in rank order. *)
  cache_capacity : int;
      (** [0] disables the cache; any other value is only the initial size of
          the cache hash table. NOTE(review): nothing evicts entries, so the
          cache is not actually bounded by this value. *)
  dropout : float option;
      (** Probability of skipping each candidate merge while tokenizing. *)
  unk_token : string option;
      (** Token used for characters with no vocabulary entry; such
          characters are dropped when [None]. *)
  continuing_subword_prefix : string option;
      (** Prepended to every character except the first when splitting a
          word; its length is also cut from the front of a merge's right-hand
          token when computing the merged token. *)
  end_of_word_suffix : string option;
      (** Appended to the last character of a word. *)
  fuse_unk : bool;
      (** Fuse consecutive unknown characters into a single unk symbol. *)
  byte_fallback : bool;
      (** Try ["<0xXX>"] byte tokens before falling back to [unk_token]. *)
  ignore_merges : bool;
      (** NOTE(review): [tokenize] only uses this to bypass the cache; the
          whole-text vocabulary lookup happens regardless of this flag. *)
}

(** A BPE model, built by [create]. *)
type t = {
  vocab : vocab;  (** Token -> id. *)
  vocab_r : vocab_r;  (** Id -> token, derived from [vocab]. *)
  merges : merge_map;  (** (left id, right id) -> (rank, merged id). *)
  cache : (string, cache_entry) Hashtbl.t option;
      (** Merged words keyed by input text; [None] when caching is off. *)
  (* The remaining fields are copied unchanged from the [config]. *)
  dropout : float option;
  unk_token : string option;
  continuing_subword_prefix : string option;
  end_of_word_suffix : string option;
  fuse_unk : bool;
  byte_fallback : bool;
  ignore_merges : bool;
}

(** [create_word capacity] is an empty word with room for [capacity]
    symbols. Unused slots all hold one shared placeholder record; they are
    never read before [add_symbol] replaces them with fresh records. *)
let create_word capacity =
  let placeholder = { c = -1; prev = -1; next = -1; len = 0 } in
  { symbols = Array.make capacity placeholder; size = 0 }

(** [add_symbol word c byte_len] appends a symbol with token id [c] covering
    [byte_len] bytes and links it after the current last symbol.
    @raise Failure if every slot of [word] is already used. *)
let add_symbol word c byte_len =
  let idx = word.size in
  if idx >= Array.length word.symbols then failwith "Word capacity exceeded";
  (* For the first symbol this is -1, i.e. "no previous symbol". *)
  let prev = idx - 1 in
  if prev >= 0 then word.symbols.(prev).next <- idx;
  word.symbols.(idx) <- { c; prev; next = -1; len = byte_len };
  word.size <- idx + 1

module PQueue = struct
  (** A minimal priority queue over an unsorted list. The ordering is only
      supplied to [pop], so [push] is O(1) and [pop] does one O(n) scan for
      the minimum instead of sorting the whole queue on every call. *)

  (** [create ()] is an empty queue. *)
  let create () = ref []

  (** [push t x] adds [x] to [t]. *)
  let push t x = t := x :: !t

  (** [pop cmp t] removes and returns a smallest element of [t] according
      to [cmp], or [None] when [t] is empty. Exactly one element is removed
      even if several compare equal. *)
  let pop cmp t =
    match !t with
    | [] -> None
    | first :: rest ->
        (* Keep the earliest element on ties. *)
        let best =
          List.fold_left
            (fun best x -> if cmp x best < 0 then x else best)
            first rest
        in
        (* Remove that very element (physical equality), tail-recursively. *)
        let rec remove acc = function
          | [] -> List.rev acc
          | x :: xs when x == best -> List.rev_append acc xs
          | x :: xs -> remove (x :: acc) xs
        in
        t := remove [] !t;
        Some best
end

(** [apply_merges model dropout word] merges adjacent symbols of [word] in
    place according to [model.merges].

    Candidate pairs are processed lowest rank first (rank = position in the
    merge list) and leftmost first among equal ranks. A popped candidate is
    ignored if its left symbol has since been consumed, has no right
    neighbour, or the pair at that position no longer maps to the same rule.
    With [dropout = Some p], each valid candidate is skipped with probability
    [p]; skipped candidates are re-queued only once a later merge succeeds.

    On return [word.symbols] holds exactly the [word.size] surviving symbols
    (those with [len > 0]), with [prev]/[next] relinked to their packed
    positions. An empty word is left empty. *)
let apply_merges model dropout word =
  let p = match dropout with Some p -> p | None -> 0.0 in
  (* Lower rank first; on equal rank, the leftmost position. *)
  let cmp (r1, p1, _) (r2, p2, _) =
    let c = r1 - r2 in
    if c = 0 then p1 - p2 else c
  in
  let queue = PQueue.create () in
  (* Queue the pair starting at [left] if the merge table has a rule for it. *)
  let push_pair left right =
    let pair = (word.symbols.(left).c, word.symbols.(right).c) in
    match IntPairMap.find_opt pair model.merges with
    | Some (rank, new_id) -> PQueue.push queue (rank, left, new_id)
    | None -> ()
  in
  for i = 0 to word.size - 2 do
    if word.symbols.(i).len > 0 && word.symbols.(i + 1).len > 0 then
      push_pair i (i + 1)
  done;
  let skips = ref [] in
  let rec process_queue () =
    match PQueue.pop cmp queue with
    | None -> ()
    | Some ((rank, pos, new_id) as top) ->
        let sym = word.symbols.(pos) in
        let next_pos = sym.next in
        if sym.len = 0 || next_pos = -1 then process_queue ()
        else (
          let right = word.symbols.(next_pos) in
          (match IntPairMap.find_opt (sym.c, right.c) model.merges with
          | Some (r, nid) when r = rank && nid = new_id ->
              (* Only draw a random number when dropout can actually fire. *)
              if p > 0.0 && Random.float 1.0 < p then skips := top :: !skips
              else (
                List.iter (PQueue.push queue) !skips;
                skips := [];
                (* Absorb [right] into [sym] and unlink it. *)
                sym.c <- new_id;
                sym.len <- sym.len + right.len;
                sym.next <- right.next;
                right.len <- 0;
                if sym.next >= 0 then word.symbols.(sym.next).prev <- pos;
                (* The merged symbol may form new pairs with both neighbours. *)
                if sym.prev >= 0 then push_pair sym.prev pos;
                if sym.next >= 0 then push_pair pos sym.next)
          | _ -> ());
          process_queue ())
  in
  process_queue ();
  (* Pack the surviving symbols to the front and relink them so that
     [prev]/[next] index the packed array. Works for an empty word too. *)
  let live =
    Array.sub word.symbols 0 word.size
    |> Array.to_list
    |> List.filter (fun sym -> sym.len > 0)
    |> Array.of_list
  in
  let n = Array.length live in
  Array.iteri
    (fun k sym ->
      sym.prev <- k - 1;
      sym.next <- (if k + 1 < n then k + 1 else -1))
    live;
  word.symbols <- live;
  word.size <- n

(** [merge_word model text] splits [text] into one symbol per UTF-8
    character and then applies the model's merges with [apply_merges].

    For each character the candidate token is the character itself, with
    [continuing_subword_prefix] prepended unless it is the first character
    and [end_of_word_suffix] appended if it ends [text]. Then:
    - if the candidate is in the vocabulary, one symbol with that id;
    - otherwise, with [byte_fallback], one ["<0xXX>"] symbol per byte of the
      raw character, provided every such byte token exists;
    - otherwise an [unk_token] symbol (consecutive unknowns are fused into
      one when [fuse_unk]); with no [unk_token] the character is dropped.
    Malformed UTF-8 byte sequences are skipped.

    @raise Failure if [unk_token] is needed but missing from the vocabulary. *)
let merge_word model text =
  let len = String.length text in
  (* Each symbol covers at least one byte, so [len] slots always suffice. *)
  let word = create_word len in
  (* Force UTF-8. Left unspecified, Uutf guesses the encoding from the first
     bytes and decodes inputs such as "a\000..." or "\xFF\xFE..." as UTF-16. *)
  let decoder = Uutf.decoder ~encoding:`UTF_8 (`String text) in
  (* Bytes consumed so far, malformed ones included, so that [is_last] is
     computed against the real end of [text]. *)
  let consumed = ref 0 in
  let seen_char = ref false in
  (* Fused run of unknown characters not yet emitted: (unk id, byte length). *)
  let pending_unk = ref None in
  let flush_unk () =
    match !pending_unk with
    | Some (unk_id, unk_len) ->
        add_symbol word unk_id unk_len;
        pending_unk := None
    | None -> ()
  in
  (* Apply the subword prefix / end-of-word suffix to a raw character. *)
  let decorate ~is_first ~is_last s =
    let s =
      match model.continuing_subword_prefix with
      | Some prefix when not is_first -> prefix ^ s
      | _ -> s
    in
    match model.end_of_word_suffix with
    | Some suffix when is_last -> s ^ suffix
    | _ -> s
  in
  (* Ids of the "<0xXX>" tokens for every byte of [s], or [None] as soon as
     one of them is missing from the vocabulary. *)
  let byte_fallback_ids s =
    let rec loop acc idx =
      if idx = String.length s then Some (List.rev acc)
      else
        let hex = Printf.sprintf "<0x%02X>" (Char.code s.[idx]) in
        match Hashtbl.find_opt model.vocab hex with
        | Some id -> loop (id :: acc) (idx + 1)
        | None -> None
    in
    loop [] 0
  in
  (* Emit (or extend, when [fuse_unk]) an unknown-token symbol covering
     [byte_len] bytes; do nothing when the model has no [unk_token]. *)
  let add_unk byte_len =
    match model.unk_token with
    | None -> ()
    | Some unk -> (
        match Hashtbl.find_opt model.vocab unk with
        | None ->
            failwith
              (Printf.sprintf "Unknown token '%s' not in vocabulary" unk)
        | Some unk_id when model.fuse_unk ->
            pending_unk :=
              Some
                (match !pending_unk with
                | Some (id, n) -> (id, n + byte_len)
                | None -> (unk_id, byte_len))
        | Some unk_id ->
            flush_unk ();
            add_symbol word unk_id byte_len)
  in
  let rec process_chars () =
    match Uutf.decode decoder with
    | `Uchar u ->
        let char_str =
          let buf = Buffer.create 4 in
          Uutf.Buffer.add_utf_8 buf u;
          Buffer.contents buf
        in
        let byte_len = String.length char_str in
        consumed := !consumed + byte_len;
        let is_first = not !seen_char in
        seen_char := true;
        let is_last = !consumed >= len in
        let token_str = decorate ~is_first ~is_last char_str in
        (match Hashtbl.find_opt model.vocab token_str with
        | Some id ->
            flush_unk ();
            add_symbol word id byte_len
        | None -> (
            let fallback =
              if model.byte_fallback then byte_fallback_ids char_str else None
            in
            match fallback with
            | Some ids ->
                flush_unk ();
                List.iter (fun id -> add_symbol word id 1) ids
            | None -> add_unk byte_len));
        process_chars ()
    | `Malformed bytes ->
        consumed := !consumed + String.length bytes;
        process_chars ()
    | `End -> flush_unk ()
    | `Await -> assert false (* never produced for a `String source *)
  in
  process_chars ();
  apply_merges model model.dropout word;
  word

(** [word_to_tokens model word] converts the live symbols of [word] (those
    with [len > 0]) into tokens, left to right. Offsets are cumulative: the
    first token starts at byte 0 and each token starts where the previous
    one ended. Ids with no reverse-vocabulary entry get the value ["<unk>"]. *)
let word_to_tokens model word =
  let step (offset, acc) sym =
    if sym.len > 0 then
      let value =
        Option.value ~default:"<unk>" (Hashtbl.find_opt model.vocab_r sym.c)
      in
      let stop = offset + sym.len in
      (stop, { id = sym.c; value; offsets = (offset, stop) } :: acc)
    else (offset, acc)
  in
  let _, reversed =
    Array.fold_left step (0, []) (Array.sub word.symbols 0 word.size)
  in
  List.rev reversed

(** [tokenize model text] splits [text] into BPE tokens whose offsets are
    byte spans relative to [text].

    - [""] yields [[]].
    - If [text] itself is a vocabulary entry it is returned as a single
      token without running any merges (regardless of [ignore_merges]).
    - Otherwise [text] is split and merged by [merge_word]. Results for texts
      shorter than 1000 bytes are memoised in the model cache, but only when
      merging is deterministic (no dropout, or dropout [<= 0]): caching a
      randomly dropped segmentation would freeze it and defeat dropout.
    - [ignore_merges = true] bypasses the cache. *)
let tokenize model text =
  if String.length text = 0 then []
  else
    match Hashtbl.find_opt model.vocab text with
    | Some id -> [ { id; value = text; offsets = (0, String.length text) } ]
    | None -> (
        let deterministic =
          match model.dropout with None -> true | Some p -> p <= 0.0
        in
        match model.cache with
        | Some cache
          when deterministic && (not model.ignore_merges)
               && String.length text < 1000 -> (
            match Hashtbl.find_opt cache text with
            | Some word -> word_to_tokens model word
            | None ->
                let word = merge_word model text in
                Hashtbl.replace cache text word;
                word_to_tokens model word)
        | _ -> word_to_tokens model (merge_word model text))

(** [token_to_id model token] is the id of [token], if it is in the vocabulary. *)
let token_to_id { vocab; _ } token = Hashtbl.find_opt vocab token

(** [id_to_token model id] is the token string for [id], if any. *)
let id_to_token { vocab_r; _ } id = Hashtbl.find_opt vocab_r id

(** [get_vocab model] lists every (token, id) binding, in unspecified order. *)
let get_vocab { vocab; _ } =
  Hashtbl.fold (fun token id acc -> (token, id) :: acc) vocab []

(** [get_vocab_size model] is the number of bindings in the vocabulary. *)
let get_vocab_size { vocab; _ } = Hashtbl.length vocab

(** Accessors for the corresponding configuration options. *)
let get_unk_token { unk_token; _ } = unk_token

let get_continuing_subword_prefix { continuing_subword_prefix; _ } =
  continuing_subword_prefix

let get_end_of_word_suffix { end_of_word_suffix; _ } = end_of_word_suffix

(** [clear_cache model] drops every memoised segmentation; it does nothing
    when the model was created without a cache. *)
let clear_cache model = Option.iter Hashtbl.clear model.cache

(** [resize_cache model capacity] only clears the cache: [capacity] is
    ignored, since the cache is a plain hash table with no size bound. *)
let resize_cache model _capacity = clear_cache model

(** [convert_merges_to_merge_map vocab merges continuing_subword_prefix]
    compiles string merge rules into a map
    (left id, right id) -> (rank, merged id), where the rank is the rule's
    index in [merges]. If the same pair appears several times, the last
    occurrence wins.

    The merged token is [a ^ b], except that when [b] starts with the
    (non-empty) continuing-subword prefix that prefix is dropped first,
    e.g. [("ab", "##c")] -> ["abc"]. A [b] that does not carry the prefix is
    used unchanged.

    @raise Failure if [a], [b] or the merged token is not in [vocab]. *)
let convert_merges_to_merge_map vocab merges continuing_subword_prefix =
  let merged_token a b =
    match continuing_subword_prefix with
    | Some prefix
      when String.length prefix > 0
           && String.length b > String.length prefix
           && String.starts_with ~prefix b ->
        let plen = String.length prefix in
        a ^ String.sub b plen (String.length b - plen)
    | _ -> a ^ b
  in
  let _, merge_map =
    List.fold_left
      (fun (rank, acc) (a, b) ->
        match (Hashtbl.find_opt vocab a, Hashtbl.find_opt vocab b) with
        | Some a_id, Some b_id -> (
            let new_token = merged_token a b in
            match Hashtbl.find_opt vocab new_token with
            | Some new_id ->
                (rank + 1, IntPairMap.add (a_id, b_id) (rank, new_id) acc)
            | None ->
                failwith
                  (Printf.sprintf "Merge token '%s' not in vocabulary"
                     new_token))
        | _ -> failwith "Merge tokens not in vocabulary")
      (0, IntPairMap.empty) merges
  in
  merge_map

(** [create cfg] builds a model from [cfg]:
    - [cfg.vocab] is used as-is (not copied) and inverted into the
      id -> token table;
    - a cache is allocated unless [cfg.cache_capacity] is [0];
    - [cfg.merges] is compiled with [convert_merges_to_merge_map].
    @raise Failure if a merge or its result is missing from the vocabulary. *)
let create (cfg : config) : t =
  let vocab_r = Hashtbl.create (Hashtbl.length cfg.vocab) in
  cfg.vocab |> Hashtbl.iter (fun token id -> Hashtbl.add vocab_r id token);
  let cache =
    match cfg.cache_capacity with
    | 0 -> None
    | capacity -> Some (Hashtbl.create capacity)
  in
  let merges =
    convert_merges_to_merge_map cfg.vocab cfg.merges
      cfg.continuing_subword_prefix
  in
  {
    vocab = cfg.vocab;
    vocab_r;
    merges;
    cache;
    dropout = cfg.dropout;
    unk_token = cfg.unk_token;
    continuing_subword_prefix = cfg.continuing_subword_prefix;
    end_of_word_suffix = cfg.end_of_word_suffix;
    fuse_unk = cfg.fuse_unk;
    byte_fallback = cfg.byte_fallback;
    ignore_merges = cfg.ignore_merges;
  }

(** [read_files ~vocab_file ~merges_file] loads a vocabulary and a merge
    list.

    - [vocab_file]: a JSON object mapping each token to its id; float ids
      are truncated to ints. A duplicated key keeps its last value.
    - [merges_file]: one ["left right"] merge per line, separated by a single
      space. Empty lines and lines starting with ["#version"] are skipped; a
      trailing carriage return (CRLF files) is ignored.

    Returns the vocabulary and the merges in file order. Both files are
    closed even when parsing fails.
    @raise Failure on an invalid vocabulary structure or merge line (JSON
    syntax errors are raised by Yojson).
    @raise Sys_error if a file cannot be opened. *)
let read_files ~vocab_file ~merges_file =
  let vocab_json =
    (* Binary mode: the whole file is read byte-for-byte, with no text-mode
       newline translation. *)
    In_channel.with_open_bin vocab_file In_channel.input_all
    |> Yojson.Basic.from_string
  in
  let vocab = Hashtbl.create 1024 in
  (match vocab_json with
  | `Assoc items ->
      List.iter
        (fun (k, v) ->
          match v with
          | `Int id -> Hashtbl.replace vocab k id
          | `Float f -> Hashtbl.replace vocab k (int_of_float f)
          | _ -> failwith "Invalid vocab format")
        items
  | _ -> failwith "Invalid vocab.json format");
  let strip_cr line =
    let n = String.length line in
    if n > 0 && line.[n - 1] = '\r' then String.sub line 0 (n - 1) else line
  in
  let parse_merge_line line =
    match String.split_on_char ' ' line with
    | [ a; b ] -> (a, b)
    | _ -> failwith (Printf.sprintf "Invalid merge line: %s" line)
  in
  let merges =
    In_channel.with_open_text merges_file (fun ic ->
        let rec loop acc =
          match In_channel.input_line ic with
          | None -> List.rev acc
          | Some raw ->
              let line = strip_cr raw in
              if line = "" || String.starts_with ~prefix:"#version" line then
                loop acc
              else loop (parse_merge_line line :: acc)
        in
        loop [])
  in
  (vocab, merges)

(* Configuration shared by [from_files] and [default]: a 10000-entry cache,
   no dropout, no unknown token, no prefix/suffix, all flags off. *)
let config_with_defaults vocab merges : config =
  {
    vocab;
    merges;
    cache_capacity = 10000;
    dropout = None;
    unk_token = None;
    continuing_subword_prefix = None;
    end_of_word_suffix = None;
    fuse_unk = false;
    byte_fallback = false;
    ignore_merges = false;
  }

(** [from_files ~vocab_file ~merges_file] loads both files with
    [read_files] and builds a model with default options. *)
let from_files ~vocab_file ~merges_file =
  let vocab, merges = read_files ~vocab_file ~merges_file in
  create (config_with_defaults vocab merges)

(** [default ()] is a model with an empty vocabulary, no merges and default
    options. *)
let default () = create (config_with_defaults (Hashtbl.create 0) [])

(** [save model ~path ?name ()] writes two files into directory [path]:
    - [vocab.json] (or [<name>-vocab.json]): a JSON object token -> id, with
      bindings in increasing id order;
    - [merges.txt] (or [<name>-merges.txt]): a ["#version: 0.2"] header line
      followed by one ["left right"] merge per line, in rank order.
    Merges whose ids have no reverse-vocabulary entry are omitted. Each file
    is closed even if writing it raises. *)
let save model ~path ?name () =
  let file_name base =
    match name with
    | Some n -> Filename.concat path (Printf.sprintf "%s-%s" n base)
    | None -> Filename.concat path base
  in
  let vocab_file = file_name "vocab.json" in
  let merges_file = file_name "merges.txt" in
  let vocab_json : Yojson.Basic.t =
    let items =
      Hashtbl.fold (fun token id acc -> (token, id) :: acc) model.vocab []
      |> List.sort (fun (_, a) (_, b) -> compare a b)
    in
    `Assoc (List.map (fun (token, id) -> (token, `Int id)) items)
  in
  Out_channel.with_open_text vocab_file (fun oc ->
      output_string oc (Yojson.Basic.to_string vocab_json));
  let merges_list =
    IntPairMap.fold
      (fun (a_id, b_id) (rank, _) acc ->
        match
          ( Hashtbl.find_opt model.vocab_r a_id,
            Hashtbl.find_opt model.vocab_r b_id )
        with
        | Some a, Some b -> (rank, a, b) :: acc
        | _ -> acc)
      model.merges []
    |> List.sort (fun (r1, _, _) (r2, _, _) -> compare r1 r2)
  in
  Out_channel.with_open_text merges_file (fun oc ->
      output_string oc "#version: 0.2\n";
      List.iter (fun (_, a, b) -> Printf.fprintf oc "%s %s\n" a b) merges_list)

(* Alias of [create], so that [Builder] and [Trainer] — which define their
   own [create] — can still reach the model constructor. *)
let create_internal cfg = create cfg

module Builder = struct
  (** Mutable accumulator for a model configuration. Every setter updates
      the builder in place and returns it; {!build} creates the model. *)
  type builder = {
    mutable vocab : vocab;
    mutable merges : merges;
    mutable cache_capacity : int;
    mutable dropout : float option;
    mutable unk_token : string option;
    mutable continuing_subword_prefix : string option;
    mutable end_of_word_suffix : string option;
    mutable fuse_unk : bool;
    mutable byte_fallback : bool;
    mutable ignore_merges : bool;
  }

  (** [create ()] starts from an empty vocabulary, no merges, a 10000-entry
      cache, and every optional feature disabled. *)
  let create () =
    {
      vocab = Hashtbl.create 0;
      merges = [];
      cache_capacity = 10000;
      dropout = None;
      unk_token = None;
      continuing_subword_prefix = None;
      end_of_word_suffix = None;
      fuse_unk = false;
      byte_fallback = false;
      ignore_merges = false;
    }

  let vocab_and_merges builder vocab merges =
    builder.vocab <- vocab;
    builder.merges <- merges;
    builder

  let cache_capacity builder capacity =
    builder.cache_capacity <- capacity;
    builder

  (** [dropout builder p] sets the merge-dropout probability.
      @raise Failure unless [0.0 <= p <= 1.0] (NaN is rejected). *)
  let dropout builder p =
    (* Written as a positive range test: every comparison with NaN is false,
       so NaN fails it instead of slipping through. *)
    if not (p >= 0.0 && p <= 1.0) then
      failwith "Dropout must be between 0.0 and 1.0";
    builder.dropout <- Some p;
    builder

  let unk_token builder token =
    builder.unk_token <- Some token;
    builder

  let continuing_subword_prefix builder prefix =
    builder.continuing_subword_prefix <- Some prefix;
    builder

  let end_of_word_suffix builder suffix =
    builder.end_of_word_suffix <- Some suffix;
    builder

  let fuse_unk builder fuse =
    builder.fuse_unk <- fuse;
    builder

  let byte_fallback builder fallback =
    builder.byte_fallback <- fallback;
    builder

  let ignore_merges builder ignore =
    builder.ignore_merges <- ignore;
    builder

  (** [build b] creates a model from the accumulated settings.
      @raise Failure if a merge or its result is missing from the vocabulary. *)
  let build (b : builder) =
    let cfg : config =
      {
        vocab = b.vocab;
        merges = b.merges;
        cache_capacity = b.cache_capacity;
        dropout = b.dropout;
        unk_token = b.unk_token;
        continuing_subword_prefix = b.continuing_subword_prefix;
        end_of_word_suffix = b.end_of_word_suffix;
        fuse_unk = b.fuse_unk;
        byte_fallback = b.byte_fallback;
        ignore_merges = b.ignore_merges;
      }
    in
    create_internal cfg
end

module Trainer = struct
  (** Word -> number of times it was seen by {!feed}. *)
  type word_count = (string, int) Hashtbl.t

  (** Options controlling {!train}. *)
  type trainer_config = {
    min_frequency : int;
        (** Training stops once the most frequent pair occurs fewer times. *)
    vocab_size : int;
        (** Target vocabulary size, special tokens and alphabet included. *)
    show_progress : bool;  (** Not used by this implementation. *)
    special_tokens : string list;  (** Added first, with the lowest ids. *)
    limit_alphabet : int option;
        (** Keep at most this many distinct initial characters, dropping the
            least frequent ones; dropped characters are ignored in words. *)
    initial_alphabet : char list;
        (** Characters added to the alphabet with maximal frequency, so they
            are the last ones [limit_alphabet] drops. *)
    continuing_subword_prefix : string option;
        (** Prepended to every non-initial character of a word. *)
    end_of_word_suffix : string option;
        (** Appended to the last character of a word. *)
    max_token_length : int option;
        (** Merges whose result is longer than this many bytes are never
            learned. *)
  }

  type trainer = { config : trainer_config; words : word_count }

  let default_config =
    {
      min_frequency = 0;
      vocab_size = 30000;
      show_progress = true;
      special_tokens = [];
      limit_alphabet = None;
      initial_alphabet = [];
      continuing_subword_prefix = None;
      end_of_word_suffix = None;
      max_token_length = None;
    }

  (** [create config] is a trainer that has not been fed any text yet. *)
  let create config = { config; words = Hashtbl.create 10000 }

  (* [bump tbl key n] adds [n] to the count bound to [key] (absent = 0). *)
  let bump tbl key n =
    Hashtbl.replace tbl key
      (n + Option.value ~default:0 (Hashtbl.find_opt tbl key))

  (** [feed trainer texts] splits every text on single spaces and counts
      each non-empty word. *)
  let feed trainer texts =
    List.iter
      (fun text ->
        List.iter
          (fun word -> if word <> "" then bump trainer.words word 1)
          (String.split_on_char ' ' text))
      texts

  (** [compute_pair_counts words_copy] counts adjacent symbol pairs. Keys of
      [words_copy] are words written as space-separated symbols, values their
      frequencies; each pair is weighted by the frequency of its word. *)
  let compute_pair_counts words_copy =
    let pair_counts = Hashtbl.create 10000 in
    Hashtbl.iter
      (fun word count ->
        (* Walk adjacent symbols directly (no O(n^2) [List.nth]). *)
        let rec walk = function
          | a :: (b :: _ as rest) ->
              bump pair_counts (a, b) count;
              walk rest
          | [] | [ _ ] -> ()
        in
        walk (String.split_on_char ' ' word))
      words_copy;
    pair_counts

  (* The characters of [s] decoded as UTF-8 (each re-encoded as a string),
     in order; malformed byte sequences are skipped. The encoding is forced:
     left unspecified, Uutf may guess UTF-16 from the first bytes. *)
  let utf8_chars s =
    let decoder = Uutf.decoder ~encoding:`UTF_8 (`String s) in
    let rec loop acc =
      match Uutf.decode decoder with
      | `Uchar u ->
          let buf = Buffer.create 4 in
          Uutf.Buffer.add_utf_8 buf u;
          loop (Buffer.contents buf :: acc)
      | `End -> List.rev acc
      | `Malformed _ | `Await -> loop acc
    in
    loop []

  (* Replace each left-to-right, non-overlapping occurrence of the adjacent
     symbols [a] [b] in the space-separated [word] with [merged]. Unlike a
     raw substring replacement this respects symbol boundaries. Assumes no
     symbol contains a space. *)
  let merge_pair_in_word a b merged word =
    let rec go acc = function
      | x :: y :: rest when String.equal x a && String.equal y b ->
          go (merged :: acc) rest
      | x :: rest -> go (x :: acc) rest
      | [] -> List.rev acc
    in
    String.concat " " (go [] (String.split_on_char ' ' word))

  (** [train trainer _model] learns BPE merges from the words fed so far.

      The vocabulary is built as:
      + the special tokens;
      + the (possibly limited) alphabet, sorted;
      + the decorated per-character symbols (prefix/suffix);
      + one new token per learned merge.
      Each step merges the most frequent adjacent pair; ties go to the
      lexicographically smallest pair. Pairs whose merged token exceeds
      [max_token_length] are ignored. Training stops when [vocab_size] is
      reached, no pair is left, or the best pair occurs fewer than
      [min_frequency] times.

      A model is then built from the result, which validates the merges
      and may raise [Failure]. The function returns
      [trainer.config.special_tokens].

      NOTE(review): the trained model is discarded and [_model] is not
      updated, so callers only get the special-token list back. Confirm the
      intended API. *)
  let train trainer _model =
    let cfg = trainer.config in
    let vocab = Hashtbl.create 10000 in
    let vocab_size = ref 0 in
    let add_token token =
      if not (Hashtbl.mem vocab token) then (
        Hashtbl.add vocab token !vocab_size;
        incr vocab_size)
    in
    (* 1. Special tokens take the lowest ids, in the order given. *)
    List.iter add_token cfg.special_tokens;
    (* 2. Alphabet: each character weighted by the counts of the words that
       contain it, plus the forced initial alphabet. *)
    let alphabet = Hashtbl.create 10000 in
    Hashtbl.iter
      (fun word count ->
        List.iter (fun c -> bump alphabet c count) (utf8_chars word))
      trainer.words;
    List.iter
      (fun c -> Hashtbl.replace alphabet (String.make 1 c) max_int)
      cfg.initial_alphabet;
    let by_frequency =
      Hashtbl.fold (fun c n acc -> (c, n) :: acc) alphabet []
      |> List.sort (fun (_, n1) (_, n2) -> compare n1 n2)
    in
    let to_remove =
      match cfg.limit_alphabet with
      | Some limit -> max 0 (List.length by_frequency - limit)
      | None -> 0
    in
    List.drop to_remove by_frequency
    |> List.sort (fun (c1, _) (c2, _) -> compare c1 c2)
    |> List.iter (fun (c, _) -> add_token c);
    (* 3. Split every word into symbols. Characters outside the vocabulary
       (dropped by [limit_alphabet]) are skipped, otherwise merges could
       reference tokens the model does not know. Non-initial characters get
       the subword prefix and the final one the end-of-word suffix, as at
       tokenization time; these decorated forms join the vocabulary. *)
    let decorate ~is_first ~is_last c =
      let c =
        match cfg.continuing_subword_prefix with
        | Some prefix when not is_first -> prefix ^ c
        | _ -> c
      in
      match cfg.end_of_word_suffix with
      | Some suffix when is_last -> c ^ suffix
      | _ -> c
    in
    let words_copy = ref (Hashtbl.create (Hashtbl.length trainer.words)) in
    Hashtbl.iter
      (fun word count ->
        let chars = utf8_chars word in
        let last = List.length chars - 1 in
        let symbols = ref [] in
        List.iteri
          (fun i c ->
            if Hashtbl.mem vocab c then (
              let symbol = decorate ~is_first:(i = 0) ~is_last:(i = last) c in
              add_token symbol;
              symbols := symbol :: !symbols))
          chars;
        Hashtbl.add !words_copy (String.concat " " (List.rev !symbols)) count)
      trainer.words;
    (* 4. Merge loop. The merged token drops the continuation prefix of the
       right-hand symbol, e.g. ("a", "##b") -> "ab". *)
    let strip_prefix b =
      match cfg.continuing_subword_prefix with
      | Some prefix when String.starts_with ~prefix b ->
          let plen = String.length prefix in
          String.sub b plen (String.length b - plen)
      | _ -> b
    in
    let too_long token =
      match cfg.max_token_length with
      | Some limit -> String.length token > limit
      | None -> false
    in
    let merges = ref [] in
    let finished = ref false in
    while (not !finished) && !vocab_size < cfg.vocab_size do
      let pair_counts = compute_pair_counts !words_copy in
      (* Too-long pairs are filtered here; if they could win, the same pair
         would be picked again on every iteration and training would spin. *)
      let best =
        Hashtbl.fold
          (fun ((a, b) as pair) count best ->
            if too_long (a ^ strip_prefix b) then best
            else
              match best with
              | Some (best_pair, best_count)
                when best_count > count
                     || (best_count = count && compare best_pair pair <= 0) ->
                  best
              | _ -> Some (pair, count))
          pair_counts None
      in
      match best with
      | None -> finished := true
      | Some (_, count) when count < cfg.min_frequency -> finished := true
      | Some ((a, b), _) ->
          let new_token = a ^ strip_prefix b in
          add_token new_token;
          merges := (a, b) :: !merges;
          let merged_words = Hashtbl.create (Hashtbl.length !words_copy) in
          Hashtbl.iter
            (fun word count ->
              Hashtbl.add merged_words
                (merge_pair_in_word a b new_token word)
                count)
            !words_copy;
          words_copy := merged_words
    done;
    let bpe_config : config =
      {
        vocab;
        merges = List.rev !merges;
        cache_capacity = 10000;
        dropout = None;
        unk_token = None;
        continuing_subword_prefix = cfg.continuing_subword_prefix;
        end_of_word_suffix = cfg.end_of_word_suffix;
        fuse_unk = false;
        byte_fallback = false;
        ignore_merges = false;
      }
    in
    let _trained_model = create_internal bpe_config in
    cfg.special_tokens
end