Source file encoding.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
(** Encoding module - represents the output of a tokenizer.
    All per-token arrays below are parallel: index [i] describes token [i]. *)
type t = {
(* Vocabulary id of each token. *)
ids : int array;
(* Type (segment) id of each token. *)
type_ids : int array;
(* String form of each token. *)
tokens : string array;
(* Index of the source word each token came from, if known. *)
words : int option array;
(* [(start, stop)] character offsets of each token in the original text. *)
offsets : (int * int) array;
(* 1 where the token is a special token, 0 otherwise (padding inserts 1s). *)
special_tokens_mask : int array;
(* 1 where the token should be attended to, 0 for padding. *)
attention_mask : int array;
(* Encodings that did not fit (produced by truncation/merging); mutable so
   they can be taken out in place, see [take_overflowing]. *)
mutable overflowing : t list;
(* Maps a sequence id to the [(start, stop)] token range it occupies. An
   empty table means a single implicit sequence 0 spanning the encoding. *)
sequence_ranges : (int, int * int) Hashtbl.t;
}
(** Internal type definition *)
(** Truncation direction: the side from which tokens are removed. *)
type truncation_direction = Left | Right
(** Padding direction: the side on which padding tokens are added. *)
type padding_direction = Left | Right
(** [create ~ids ... ~sequence_ranges] assembles an encoding directly from its
    components. No validation is performed; the per-token arrays are expected
    to all have the same length. *)
let create ~ids ~type_ids ~tokens ~words ~offsets ~special_tokens_mask
    ~attention_mask ~overflowing ~sequence_ranges =
  {
    tokens;
    ids;
    type_ids;
    offsets;
    words;
    attention_mask;
    special_tokens_mask;
    sequence_ranges;
    overflowing;
  }
(** [with_capacity len] returns a fresh encoding of [len] tokens with every
    field zero-initialised: ids, type ids and masks are 0, tokens are empty
    strings, word ids are [None], offsets are [(0, 0)], and there is no
    overflow and no sequence range. *)
let with_capacity len =
  let zeros () = Array.make len 0 in
  {
    ids = zeros ();
    type_ids = zeros ();
    tokens = Array.make len "";
    words = Array.make len None;
    offsets = Array.make len (0, 0);
    special_tokens_mask = zeros ();
    attention_mask = zeros ();
    overflowing = [];
    sequence_ranges = Hashtbl.create 1;
  }
(** [from_tokens tokens ~type_id] builds an encoding from a list of
    [(id, token, (start, stop))] triples. Every token receives type id
    [type_id] and attention mask 1, is marked non-special, and has no word
    id. The result carries no overflow and no explicit sequence range. *)
let from_tokens tokens ~type_id =
  let triples = Array.of_list tokens in
  let length = Array.length triples in
  {
    ids = Array.map (fun (id, _, _) -> id) triples;
    tokens = Array.map (fun (_, tok, _) -> tok) triples;
    offsets = Array.map (fun (_, _, off) -> off) triples;
    words = Array.make length None;
    type_ids = Array.make length type_id;
    attention_mask = Array.make length 1;
    special_tokens_mask = Array.make length 0;
    overflowing = [];
    sequence_ranges = Hashtbl.create 1;
  }
(** [is_empty t] is true when the encoding holds no tokens. *)
let is_empty t = Array.length t.ids = 0
(** [length t] is the number of tokens in the encoding. *)
let length t = Array.length t.ids
(** [n_sequences t] is the number of sequences in the encoding. An encoding
    with no recorded sequence ranges is treated as one implicit sequence. *)
let n_sequences t =
if Hashtbl.length t.sequence_ranges = 0 then 1
else Hashtbl.length t.sequence_ranges
(** [set_sequence_id t sequence_id] returns a copy of [t] in which
    [sequence_id] maps to the whole token range [0, length t). The range
    table is copied first, so [t] itself is left untouched. *)
let set_sequence_id t sequence_id =
  let ranges = Hashtbl.copy t.sequence_ranges in
  Hashtbl.replace ranges sequence_id (0, length t);
  { t with sequence_ranges = ranges }
(** Getters: each returns the underlying field without copying, so callers
    share (and can mutate) the encoding's arrays. *)
let get_ids t = t.ids
let get_type_ids t = t.type_ids
let get_tokens t = t.tokens
let get_word_ids t = t.words
let get_offsets t = t.offsets
let get_special_tokens_mask t = t.special_tokens_mask
let get_attention_mask t = t.attention_mask
let get_overflowing t = t.overflowing
(** [set_type_ids t type_ids] returns a copy of [t] with its type ids
    replaced; [t] is unchanged. *)
let set_type_ids t type_ids = { t with type_ids }
(** [set_overflowing t overflowing] returns a copy of [t] with its overflow
    list replaced; [t] is unchanged. *)
let set_overflowing t overflowing = { t with overflowing }
(** [take_overflowing t] removes and returns the overflowing encodings,
    leaving [t.overflowing] empty. Note: this mutates [t] in place; the
    returned encoding is [t] itself. *)
let take_overflowing t =
let overflowing = t.overflowing in
t.overflowing <- [];
(t, overflowing)
(** [get_sequence_ids t] returns, for each token, the id of the sequence it
    belongs to ([None] for tokens covered by no range).

    Fix: a sequence id with no entry in [sequence_ranges] used to be silently
    skipped, so an encoding with no explicit ranges reported [None] for every
    token even though [token_to_sequence] maps those same tokens to sequence
    0. A missing range now defaults to the whole encoding — the same fallback
    [sequence_range] uses (inlined here because [sequence_range] is defined
    later in the file). *)
let get_sequence_ids t =
  let sequences = Array.make (length t) None in
  for seq_id = 0 to n_sequences t - 1 do
    let start, stop =
      match Hashtbl.find_opt t.sequence_ranges seq_id with
      | Some range -> range
      | None -> (0, length t)
    in
    (* Clamp to the array in case a stored range overruns the encoding. *)
    for i = start to min stop (Array.length sequences) - 1 do
      sequences.(i) <- Some seq_id
    done
  done;
  sequences
(** [sequence_range t sequence_id] is the [(start, stop)] token range of the
    given sequence, defaulting to the whole encoding when no range is
    recorded for that id. *)
let sequence_range t sequence_id =
match Hashtbl.find_opt t.sequence_ranges sequence_id with
| Some range -> range
| None -> (0, length t)
(** [token_to_sequence t token] is the id of the sequence whose range contains
    token index [token], or [None] when the index is out of bounds or covered
    by no range. With no recorded ranges, every valid index belongs to the
    implicit sequence 0. *)
let token_to_sequence t token =
  if token >= length t then None
  else if Hashtbl.length t.sequence_ranges = 0 then Some 0
  else
    (* Keep the first matching range encountered during the fold. *)
    Hashtbl.fold
      (fun seq_id (lo, hi) found ->
        if found <> None then found
        else if lo <= token && token < hi then Some seq_id
        else None)
      t.sequence_ranges None
(** [word_to_tokens t ~word ~sequence_id] returns the half-open token range
    [(start, stop)] covered by word [word] inside the given sequence, or
    [None] when no token of that sequence maps to the word. *)
let word_to_tokens t ~word ~sequence_id =
  let lo, hi = sequence_range t sequence_id in
  let bounds = ref None in
  (* Left-to-right scan: the first hit fixes [start], every later hit
     extends [stop]. *)
  for i = lo to hi - 1 do
    if t.words.(i) = Some word then
      bounds :=
        (match !bounds with
        | None -> Some (i, i + 1)
        | Some (start, _) -> Some (start, i + 1))
  done;
  !bounds
(** [word_to_chars t ~word ~sequence_id] is the character span
    [(start, stop)] of the given word: the start offset of its first token
    through the end offset of its last token. [None] if the word is absent
    from the sequence. *)
let word_to_chars t ~word ~sequence_id =
  match word_to_tokens t ~word ~sequence_id with
  | None -> None
  | Some (_, 0) -> None
  | Some (first, last) ->
      let char_start = fst t.offsets.(first) in
      let char_end = snd t.offsets.(last - 1) in
      Some (char_start, char_end)
(** [token_to_chars t token] is the sequence id together with the character
    offsets of the given token, or [None] when the token index is out of
    range. *)
let token_to_chars t token =
match token_to_sequence t token with
| Some seq_id when token < Array.length t.offsets ->
Some (seq_id, t.offsets.(token))
| _ -> None
(** [token_to_word t token] is the sequence id together with the word index
    of the given token, or [None] when the token is out of range or has no
    recorded word. *)
let token_to_word t token =
  match token_to_sequence t token with
  | None -> None
  | Some seq_id -> Option.map (fun word -> (seq_id, word)) t.words.(token)
(** [char_to_token t ~pos ~sequence_id] is the index of the first token of the
    given sequence whose character span contains position [pos], or [None]
    when no token covers it. *)
let char_to_token t ~pos ~sequence_id =
  let lo, hi = sequence_range t sequence_id in
  let result = ref None in
  let i = ref lo in
  while !result = None && !i < hi do
    let start_c, end_c = t.offsets.(!i) in
    if start_c <= pos && pos < end_c then result := Some !i else incr i
  done;
  !result
(** [char_to_word t ~pos ~sequence_id] is the word index containing character
    position [pos] within the given sequence, or [None] when no token covers
    [pos] or the covering token has no word id.
    (Simplified: the previous [Option.bind x (fun w -> Some w)] was an
    identity on the option — the lookup itself already has the right type.) *)
let char_to_word t ~pos ~sequence_id =
  match char_to_token t ~pos ~sequence_id with
  | Some token -> t.words.(token)
  | None -> None
(** [array_slice arr start stop] is the half-open slice
    [arr.(start) .. arr.(stop - 1)] as a fresh array. Uses the stdlib
    [Array.sub] instead of the previous hand-rolled [Array.init] copy. *)
let array_slice arr start stop = Array.sub arr start (stop - start)
(** [truncate t ~max_length ~stride ~direction] truncates the encoding to at
    most [max_length] tokens. Removed tokens are split into overlapping
    chunks (consecutive chunks share [stride] tokens) stored in
    [overflowing]. [direction] selects the side that is cut: [Right] keeps
    the first [max_length] tokens, [Left] keeps the last [max_length].

    Note: [t]'s own [sequence_ranges] table is cleared in place (the input is
    effectively consumed).

    Fixes:
    - For [direction = Left] the chunk list was accumulated in ascending
      order, so the *leftmost* chunk was kept as the main encoding — the
      opposite of left truncation. The chunk ending at [encoding_len] now
      comes first.
    - Chunk generation now stops once a chunk reaches the boundary
      ([stop = encoding_len] for [Right], [start = 0] for [Left]); previously
      extra redundant chunks, fully contained in the previous one, could be
      produced. *)
let truncate t ~max_length ~stride ~(direction : truncation_direction) =
  let encoding_len = length t in
  if max_length >= encoding_len then t
  else if max_length = 0 then (
    (* Everything overflows: return an empty encoding carrying [t]. *)
    let empty = with_capacity 0 in
    empty.overflowing <- [ t ];
    empty)
  else (
    assert (stride < max_length);
    Hashtbl.clear t.sequence_ranges;
    (* Each chunk starts [offset] tokens after the previous one. *)
    let offset = max_length - stride in
    let parts_ranges =
      match direction with
      | Right ->
          let rec collect start acc =
            if start >= encoding_len then List.rev acc
            else
              let stop = min (start + max_length) encoding_len in
              let acc = (start, stop) :: acc in
              if stop = encoding_len then List.rev acc
              else collect (start + offset) acc
          in
          collect 0 []
      | Left ->
          (* Walk right to left; the first chunk collected ends at
             [encoding_len] and is the one left truncation keeps. *)
          let rec collect stop acc =
            if stop <= 0 then List.rev acc
            else
              let start = max 0 (stop - max_length) in
              let acc = (start, stop) :: acc in
              if start = 0 then List.rev acc else collect (stop - offset) acc
          in
          collect encoding_len []
    in
    (* Build a fresh encoding from a token index range of [t]. *)
    let slice_range (start, stop) =
      {
        ids = array_slice t.ids start stop;
        type_ids = array_slice t.type_ids start stop;
        tokens = array_slice t.tokens start stop;
        words = array_slice t.words start stop;
        offsets = array_slice t.offsets start stop;
        special_tokens_mask = array_slice t.special_tokens_mask start stop;
        attention_mask = array_slice t.attention_mask start stop;
        overflowing = [];
        sequence_ranges = Hashtbl.create 1;
      }
    in
    match parts_ranges with
    | [] -> with_capacity 0
    | first :: rest ->
        let new_encoding = slice_range first in
        new_encoding.overflowing <- List.map slice_range rest;
        new_encoding)
(** [merge encodings ~growing_offsets] folds a list of encodings into a
    single one, left to right, using [merge_with]. The empty list yields an
    empty encoding. *)
let rec merge encodings ~growing_offsets =
let rec merge_list acc = function
| [] -> acc
| e :: rest -> merge_list (merge_with acc e ~growing_offsets) rest
in
match encodings with [] -> with_capacity 0 | e :: rest -> merge_list e rest
(** [merge_with t1 t2 ~growing_offsets] concatenates [t2] onto [t1].
    - Overflow is combined as a cross product: each overflow of [t1] is merged
      with [t2] and with each overflow of [t2], and [t1] is merged with each
      overflow of [t2].
    - [t2]'s sequence ranges are shifted by [length t1] so they index into the
      concatenated arrays; on a duplicate sequence id, [t2]'s range wins.
    - When [growing_offsets] is true, [t2]'s character offsets are shifted by
      the end offset of [t1]'s last token so offsets keep increasing across
      the merged text. *)
and merge_with t1 t2 ~growing_offsets =
let original_len = length t1 in
let new_overflowing = ref [] in
(* Cross product of overflow combinations; accumulated in reverse and
   re-reversed in the final record to preserve generation order. *)
List.iter
(fun o1 ->
new_overflowing := merge_with o1 t2 ~growing_offsets :: !new_overflowing;
List.iter
(fun o2 ->
new_overflowing :=
merge_with o1 o2 ~growing_offsets :: !new_overflowing)
t2.overflowing)
t1.overflowing;
List.iter
(fun o2 ->
new_overflowing := merge_with t1 o2 ~growing_offsets :: !new_overflowing)
t2.overflowing;
(* Shift t2's sequence ranges past the end of t1. *)
let new_ranges = Hashtbl.copy t1.sequence_ranges in
Hashtbl.iter
(fun seq_id (start, stop) ->
Hashtbl.replace new_ranges seq_id
(original_len + start, original_len + stop))
t2.sequence_ranges;
(* Character-offset shift applied to t2 when offsets must keep growing. *)
let starting_offset =
if growing_offsets && Array.length t1.offsets > 0 then
snd t1.offsets.(Array.length t1.offsets - 1)
else 0
in
let merged_offsets =
if growing_offsets then
Array.map
(fun (start, stop) -> (start + starting_offset, stop + starting_offset))
t2.offsets
else t2.offsets
in
{
ids = Array.append t1.ids t2.ids;
type_ids = Array.append t1.type_ids t2.type_ids;
tokens = Array.append t1.tokens t2.tokens;
words = Array.append t1.words t2.words;
offsets = Array.append t1.offsets merged_offsets;
special_tokens_mask =
Array.append t1.special_tokens_mask t2.special_tokens_mask;
attention_mask = Array.append t1.attention_mask t2.attention_mask;
overflowing = List.rev !new_overflowing;
sequence_ranges = new_ranges;
}
(** [pad t ~target_length ~pad_id ~pad_type_id ~pad_token ~direction] pads
    the encoding (and, recursively, every overflowing encoding) up to
    [target_length] tokens. Encodings already at least that long come back
    unchanged apart from their padded overflow. Padding tokens get id
    [pad_id], type id [pad_type_id], text [pad_token], no word id, [(0, 0)]
    offsets, special-token mask 1 and attention mask 0. [direction] chooses
    whether padding is prepended ([Left]) or appended ([Right]); for [Left]
    every sequence range is shifted right by the pad amount. *)
let rec pad t ~target_length ~pad_id ~pad_type_id ~pad_token ~direction =
  let overflowing =
    List.map
      (fun e ->
        pad e ~target_length ~pad_id ~pad_type_id ~pad_token ~direction)
      t.overflowing
  in
  let pad_length = target_length - length t in
  if pad_length <= 0 then { t with overflowing }
  else
    (* Prepend or append [pad_length] copies of [filler]; polymorphic so the
       same helper serves every field. *)
    let extend filler arr =
      let padding = Array.make pad_length filler in
      match direction with
      | Left -> Array.append padding arr
      | Right -> Array.append arr padding
    in
    let sequence_ranges =
      match direction with
      | Right -> t.sequence_ranges
      | Left ->
          (* All tokens moved right by [pad_length]; shift ranges with them. *)
          let shifted = Hashtbl.create (Hashtbl.length t.sequence_ranges) in
          Hashtbl.iter
            (fun seq_id (start, stop) ->
              Hashtbl.add shifted seq_id
                (start + pad_length, stop + pad_length))
            t.sequence_ranges;
          shifted
    in
    {
      ids = extend pad_id t.ids;
      type_ids = extend pad_type_id t.type_ids;
      tokens = extend pad_token t.tokens;
      words = extend None t.words;
      offsets = extend (0, 0) t.offsets;
      special_tokens_mask = extend 1 t.special_tokens_mask;
      attention_mask = extend 0 t.attention_mask;
      overflowing;
      sequence_ranges;
    }
(* Flattened, serialization-friendly view of an encoding: the range table is
   replaced by an explicit per-token [sequence_ids] array plus the number of
   sequences. *)
type encoding_data = {
ids : int array;
type_ids : int array;
tokens : string array;
offsets : (int * int) array;
attention_mask : int array;
special_tokens_mask : int array;
(* Overflowing encodings are carried through unflattened. *)
overflowing : t list;
word_ids : int option array;
sequence_ids : int option array;
n_sequences : int;
}
(** Encoding data for serialization - currently unused but may be needed for
    JSON serialization. [_to_data t] flattens [t]: sequence ranges are
    expanded into the per-token [sequence_ids] array (via [get_sequence_ids])
    and the sequence count is materialised. *)
let _to_data (t : t) : encoding_data =
{
ids = t.ids;
type_ids = t.type_ids;
tokens = t.tokens;
offsets = t.offsets;
attention_mask = t.attention_mask;
special_tokens_mask = t.special_tokens_mask;
overflowing = t.overflowing;
word_ids = t.words;
sequence_ids = get_sequence_ids t;
n_sequences = n_sequences t;
}
(** [_from_data d] rebuilds an encoding from its flattened form. The range
    table is reconstructed from [d.sequence_ids] only when
    [d.n_sequences > 1]; a single sequence is left implicit (empty table). *)
let _from_data (d : encoding_data) : t =
let t =
{
ids = d.ids;
type_ids = d.type_ids;
tokens = d.tokens;
words = d.word_ids;
offsets = d.offsets;
special_tokens_mask = d.special_tokens_mask;
attention_mask = d.attention_mask;
overflowing = d.overflowing;
sequence_ranges = Hashtbl.create 1;
}
in
if d.n_sequences > 1 then (
(* Scan [sequence_ids] left to right, closing the open range whenever the
   current sequence id changes or becomes [None]. *)
let current_seq = ref None in
let start = ref 0 in
Array.iteri
(fun i seq_opt ->
match (seq_opt, !current_seq) with
| Some seq, None ->
(* A range opens at position [i]. *)
current_seq := Some seq;
start := i
| Some seq, Some curr_seq when seq <> curr_seq ->
(* Id changed: close the previous range and open a new one. *)
Hashtbl.add t.sequence_ranges curr_seq (!start, i);
current_seq := Some seq;
start := i
| None, Some curr_seq ->
(* Gap (e.g. padding): close the open range. *)
Hashtbl.add t.sequence_ranges curr_seq (!start, i);
current_seq := None
| _ -> ())
d.sequence_ids;
(* Close a range still open at the end of the array. *)
match !current_seq with
| Some seq ->
Hashtbl.add t.sequence_ranges seq (!start, Array.length d.sequence_ids)
| None -> ());
t