package saga
Text processing and NLP extensions for Nx
Install
dune-project
Dependency
Authors
Maintainers
Sources
raven-1.0.0.alpha1.tbz
sha256=8e277ed56615d388bc69c4333e43d1acd112b5f2d5d352e2453aef223ff59867
sha512=369eda6df6b84b08f92c8957954d107058fb8d3d8374082e074b56f3a139351b3ae6e3a99f2d4a4a2930dd950fd609593467e502368a13ad6217b571382da28c
doc/src/saga.tokenizers/encoding.ml.html
Source file encoding.ml
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480
(** Encoding module - represents the output of a tokenizer.

    An encoding bundles the parallel per-token arrays produced by
    tokenization ([ids], [tokens], [offsets], masks, ...) together with any
    overflowing encodings produced by truncation, and a table mapping each
    sequence index to the token range it occupies. *)

type t = {
  ids : int array;  (** Token ids, one entry per token. *)
  type_ids : int array;  (** Segment / token-type ids, parallel to [ids]. *)
  tokens : string array;  (** Token strings, parallel to [ids]. *)
  words : int option array;
      (** Word index for each token, [None] when the token has no word
          (e.g. padding; [from_tokens] also leaves these unset). *)
  offsets : (int * int) array;
      (** [(start, stop)] character offsets for each token. *)
  special_tokens_mask : int array;
      (** 1 where the token is a special token (padding included), 0
          otherwise. *)
  attention_mask : int array;  (** 1 for real tokens, 0 for padding. *)
  mutable overflowing : t list;
      (** Encodings overflowing from truncation; mutable so they can be
          taken ([take_overflowing]) or rebuilt in place ([truncate]). *)
  sequence_ranges : (int, int * int) Hashtbl.t;
      (** Maps a sequence index to its [(start, stop)] token range.
          An empty table means the whole encoding is one sequence. *)
}
(** Internal type definition *)

(** Truncation direction *)
type truncation_direction = Left | Right

(** Padding direction *)
type padding_direction = Left | Right

(** Create a new encoding from explicit field values. *)
let create ~ids ~type_ids ~tokens ~words ~offsets ~special_tokens_mask
    ~attention_mask ~overflowing ~sequence_ranges =
  {
    ids;
    type_ids;
    tokens;
    words;
    offsets;
    special_tokens_mask;
    attention_mask;
    overflowing;
    sequence_ranges;
  }

(** Create an empty encoding with [len] zero/default-filled slots. *)
let with_capacity len =
  {
    ids = Array.make len 0;
    type_ids = Array.make len 0;
    tokens = Array.make len "";
    words = Array.make len None;
    offsets = Array.make len (0, 0);
    special_tokens_mask = Array.make len 0;
    attention_mask = Array.make len 0;
    overflowing = [];
    sequence_ranges = Hashtbl.create 1;
  }

(** Build an encoding from a list of [(id, token, offset)] triples.
    All tokens get [type_id], attention mask 1, and no word index. *)
let from_tokens tokens ~type_id =
  let length = List.length tokens in
  let ids = Array.make length 0 in
  let token_strs = Array.make length "" in
  let offsets = Array.make length (0, 0) in
  List.iteri
    (fun i (id, token, offset) ->
      ids.(i) <- id;
      token_strs.(i) <- token;
      offsets.(i) <- offset)
    tokens;
  {
    ids;
    tokens = token_strs;
    offsets;
    words = Array.make length None;
    type_ids = Array.make length type_id;
    attention_mask = Array.make length 1;
    special_tokens_mask = Array.make length 0;
    overflowing = [];
    sequence_ranges = Hashtbl.create 1;
  }

(** [is_empty t] is true when the encoding holds no tokens. *)
let is_empty t = Array.length t.ids = 0

(** [length t] is the number of tokens in the encoding. *)
let length t = Array.length t.ids

(** Number of sequences in the encoding. An encoding with no recorded
    sequence ranges counts as a single sequence. *)
let n_sequences t =
  (* Hoisted: compute the table size once instead of twice. *)
  let n = Hashtbl.length t.sequence_ranges in
  if n = 0 then 1 else n

(** Assign [sequence_id] to the whole encoding, returning a new encoding
    whose range table maps [sequence_id] to [(0, length t)]. *)
let set_sequence_id t sequence_id =
  let new_ranges = Hashtbl.copy t.sequence_ranges in
  Hashtbl.replace new_ranges sequence_id (0, length t);
  { t with sequence_ranges = new_ranges }

(** Getters *)
let get_ids t = t.ids

let get_type_ids t = t.type_ids
let get_tokens t = t.tokens
let get_word_ids t = t.words
let get_offsets t = t.offsets
let get_special_tokens_mask t = t.special_tokens_mask
let get_attention_mask t = t.attention_mask
let get_overflowing t = t.overflowing

(** Set type IDs *)
let set_type_ids t type_ids = { t with type_ids }

(** Set overflowing *)
let set_overflowing t overflowing = { t with overflowing }

(** Detach and return the overflowing encodings. Mutates [t] in place:
    after the call [t.overflowing] is empty. *)
let take_overflowing t =
  let overflowing = t.overflowing in
  t.overflowing <- [];
  (t, overflowing)

(** Per-token sequence ids. Tokens outside every recorded range get [None].
    NOTE(review): when [sequence_ranges] is empty this returns all [None],
    while [token_to_sequence] returns [Some 0] in the same situation —
    confirm the asymmetry is intentional. *)
let get_sequence_ids t =
  let sequences = Array.make (length t) None in
  for seq_id = 0 to n_sequences t - 1 do
    match Hashtbl.find_opt t.sequence_ranges seq_id with
    | Some (start, stop) ->
        for i = start to stop - 1 do
          (* Guard against ranges extending past the token arrays. *)
          if i < Array.length sequences then sequences.(i) <- Some seq_id
        done
    | None -> ()
  done;
  sequences

(** Token range of [sequence_id]; falls back to the whole encoding when
    the id is not recorded. *)
let sequence_range t sequence_id =
  match Hashtbl.find_opt t.sequence_ranges sequence_id with
  | Some range -> range
  | None -> (0, length t)

(** Sequence index containing token [token], or [None] if out of bounds.
    With no recorded ranges, every in-bounds token belongs to sequence 0. *)
let token_to_sequence t token =
  if token >= length t then None
  else if Hashtbl.length t.sequence_ranges = 0 then Some 0
  else
    Hashtbl.fold
      (fun seq_id (start, stop) acc ->
        match acc with
        | Some _ -> acc
        | None -> if token >= start && token < stop then Some seq_id else None)
      t.sequence_ranges None

(** Token span [(start, stop)] (stop exclusive) covered by [word] within
    [sequence_id], or [None] when no token maps to that word. *)
let word_to_tokens t ~word ~sequence_id =
  let start_ref = ref None in
  let end_ref = ref None in
  let seq_start, seq_end = sequence_range t sequence_id in
  for i = seq_start to seq_end - 1 do
    match t.words.(i) with
    | Some w when w = word ->
        if !start_ref = None || Some i < !start_ref then start_ref := Some i;
        if !end_ref = None || Some i >= !end_ref then end_ref := Some (i + 1)
    | _ -> ()
  done;
  match (!start_ref, !end_ref) with
  | Some start, Some end_pos -> Some (start, end_pos)
  | _ -> None

(** Character span of [word]: start offset of its first token through end
    offset of its last token. *)
let word_to_chars t ~word ~sequence_id =
  match word_to_tokens t ~word ~sequence_id with
  | Some (start, end_pos) when end_pos > 0 ->
      let start_offset, _ = t.offsets.(start) in
      let _, end_offset = t.offsets.(end_pos - 1) in
      Some (start_offset, end_offset)
  | _ -> None

(** Character offsets of token [token], paired with its sequence index. *)
let token_to_chars t token =
  match token_to_sequence t token with
  | Some seq_id when token < Array.length t.offsets ->
      Some (seq_id, t.offsets.(token))
  | _ -> None

(** Word index of token [token], paired with its sequence index. *)
let token_to_word t token =
  match token_to_sequence t token with
  | Some seq_id -> (
      match t.words.(token) with
      | Some word -> Some (seq_id, word)
      | None -> None)
  | None -> None

(** First token of [sequence_id] whose half-open offset range contains
    character position [pos]. *)
let char_to_token t ~pos ~sequence_id =
  let seq_start, seq_end = sequence_range t sequence_id in
  let rec find_token i =
    if i >= seq_end then None
    else
      let start_offset, end_offset = t.offsets.(i) in
      if pos >= start_offset && pos < end_offset then Some i
      else find_token (i + 1)
  in
  find_token seq_start

(** Word index containing character position [pos] in [sequence_id]. *)
let char_to_word t ~pos ~sequence_id =
  match char_to_token t ~pos ~sequence_id with
  (* Fixed idiom: [Option.bind x (fun w -> Some w)] is the identity on
     options, so return the word slot directly. *)
  | Some token -> t.words.(token)
  | None -> None

(** Helper to copy the half-open slice [start, stop) of an array. *)
let array_slice arr start stop = Array.init (stop - start) (fun i -> arr.(start + i))

(** Truncate to at most [max_length] tokens, splitting the remainder into
    overflowing encodings whose windows overlap by [stride] tokens.
    Requires [stride < max_length] when truncation actually occurs.
    NOTE(review): this clears the INPUT's [sequence_ranges] in place via
    [Hashtbl.clear] — callers holding on to [t] observe the mutation. *)
let truncate t ~max_length ~stride ~(direction : truncation_direction) =
  let encoding_len = length t in
  if max_length >= encoding_len then t
  else if max_length = 0 then (
    (* Move everything to overflowing *)
    let empty = with_capacity 0 in
    empty.overflowing <- [ t ];
    empty)
  else (
    assert (stride < max_length);
    (* Clear sequence ranges when truncating *)
    Hashtbl.clear t.sequence_ranges;
    let offset = max_length - stride in
    (* Calculate parts ranges *)
    let parts_ranges =
      match direction with
      | Right ->
          let rec collect start acc =
            if start >= encoding_len then List.rev acc
            else
              let stop = min (start + max_length) encoding_len in
              collect (start + offset) ((start, stop) :: acc)
          in
          collect 0 []
      | Left ->
          let rec collect stop acc =
            if stop <= 0 then acc
            else
              let start = max 0 (stop - max_length) in
              collect (stop - offset) ((start, stop) :: acc)
          in
          collect encoding_len []
    in
    match parts_ranges with
    | [] -> with_capacity 0
    | (start, stop) :: rest ->
        (* Create main encoding from first part *)
        let new_encoding =
          {
            ids = array_slice t.ids start stop;
            type_ids = array_slice t.type_ids start stop;
            tokens = array_slice t.tokens start stop;
            words = array_slice t.words start stop;
            offsets = array_slice t.offsets start stop;
            special_tokens_mask = array_slice t.special_tokens_mask start stop;
            attention_mask = array_slice t.attention_mask start stop;
            overflowing = [];
            sequence_ranges = Hashtbl.create 1;
          }
        in
        (* Create overflowing encodings *)
        new_encoding.overflowing <-
          List.map
            (fun (start, stop) ->
              {
                ids = array_slice t.ids start stop;
                type_ids = array_slice t.type_ids start stop;
                tokens = array_slice t.tokens start stop;
                words = array_slice t.words start stop;
                offsets = array_slice t.offsets start stop;
                special_tokens_mask =
                  array_slice t.special_tokens_mask start stop;
                attention_mask = array_slice t.attention_mask start stop;
                overflowing = [];
                sequence_ranges = Hashtbl.create 1;
              })
            rest;
        new_encoding)

(** Left-fold a list of encodings into one with [merge_with]. An empty
    list yields an empty encoding. *)
let rec merge encodings ~growing_offsets =
  let rec merge_list acc = function
    | [] -> acc
    | e :: rest -> merge_list (merge_with acc e ~growing_offsets) rest
  in
  match encodings with
  | [] -> with_capacity 0
  | e :: rest -> merge_list e rest

(** Concatenate [t2] onto [t1]. Overflowing lists are combined pairwise
    (t1-overflow x t2, t1-overflow x t2-overflow, t1 x t2-overflow),
    [t2]'s sequence ranges are shifted by [length t1], and when
    [growing_offsets] is set [t2]'s character offsets are shifted by the
    end offset of [t1]'s last token. *)
and merge_with t1 t2 ~growing_offsets =
  let original_len = length t1 in
  (* Merge overflowing encodings *)
  let new_overflowing = ref [] in
  (* Merge t1's overflowing with t2 *)
  List.iter
    (fun o1 ->
      new_overflowing := merge_with o1 t2 ~growing_offsets :: !new_overflowing;
      List.iter
        (fun o2 ->
          new_overflowing :=
            merge_with o1 o2 ~growing_offsets :: !new_overflowing)
        t2.overflowing)
    t1.overflowing;
  (* Merge t1 with t2's overflowing *)
  List.iter
    (fun o2 ->
      new_overflowing := merge_with t1 o2 ~growing_offsets :: !new_overflowing)
    t2.overflowing;
  (* Update sequence ranges *)
  let new_ranges = Hashtbl.copy t1.sequence_ranges in
  Hashtbl.iter
    (fun seq_id (start, stop) ->
      Hashtbl.replace new_ranges seq_id
        (original_len + start, original_len + stop))
    t2.sequence_ranges;
  (* Calculate offset for growing offsets *)
  let starting_offset =
    if growing_offsets && Array.length t1.offsets > 0 then
      snd t1.offsets.(Array.length t1.offsets - 1)
    else 0
  in
  (* Merge arrays *)
  let merged_offsets =
    if growing_offsets then
      Array.map
        (fun (start, stop) -> (start + starting_offset, stop + starting_offset))
        t2.offsets
    else t2.offsets
  in
  {
    ids = Array.append t1.ids t2.ids;
    type_ids = Array.append t1.type_ids t2.type_ids;
    tokens = Array.append t1.tokens t2.tokens;
    words = Array.append t1.words t2.words;
    offsets = Array.append t1.offsets merged_offsets;
    special_tokens_mask =
      Array.append t1.special_tokens_mask t2.special_tokens_mask;
    attention_mask = Array.append t1.attention_mask t2.attention_mask;
    overflowing = List.rev !new_overflowing;
    sequence_ranges = new_ranges;
  }

(** Pad the encoding (and, recursively, its overflowing encodings) to
    [target_length] tokens. Padding slots use [pad_id]/[pad_type_id]/
    [pad_token], word [None], offset [(0, 0)], special-tokens mask 1 and
    attention mask 0. Left padding shifts sequence ranges accordingly.
    An encoding already at or beyond [target_length] is returned with only
    its overflowing list padded. *)
let rec pad t ~target_length ~pad_id ~pad_type_id ~pad_token ~direction =
  (* First pad all overflowing encodings *)
  let padded_overflowing =
    List.map
      (fun e -> pad e ~target_length ~pad_id ~pad_type_id ~pad_token ~direction)
      t.overflowing
  in
  let current_len = length t in
  if current_len >= target_length then
    { t with overflowing = padded_overflowing }
  else
    let pad_length = target_length - current_len in
    match direction with
    | Left ->
        (* Update sequence ranges *)
        let new_ranges = Hashtbl.create (Hashtbl.length t.sequence_ranges) in
        Hashtbl.iter
          (fun seq_id (start, stop) ->
            Hashtbl.add new_ranges seq_id
              (start + pad_length, stop + pad_length))
          t.sequence_ranges;
        {
          ids = Array.append (Array.make pad_length pad_id) t.ids;
          type_ids = Array.append (Array.make pad_length pad_type_id) t.type_ids;
          tokens = Array.append (Array.make pad_length pad_token) t.tokens;
          words = Array.append (Array.make pad_length None) t.words;
          offsets = Array.append (Array.make pad_length (0, 0)) t.offsets;
          special_tokens_mask =
            Array.append (Array.make pad_length 1) t.special_tokens_mask;
          attention_mask =
            Array.append (Array.make pad_length 0) t.attention_mask;
          overflowing = padded_overflowing;
          sequence_ranges = new_ranges;
        }
    | Right ->
        {
          ids = Array.append t.ids (Array.make pad_length pad_id);
          type_ids = Array.append t.type_ids (Array.make pad_length pad_type_id);
          tokens = Array.append t.tokens (Array.make pad_length pad_token);
          words = Array.append t.words (Array.make pad_length None);
          offsets = Array.append t.offsets (Array.make pad_length (0, 0));
          special_tokens_mask =
            Array.append t.special_tokens_mask (Array.make pad_length 1);
          attention_mask =
            Array.append t.attention_mask (Array.make pad_length 0);
          overflowing = padded_overflowing;
          sequence_ranges = t.sequence_ranges;
        }

(** Flattened view of an encoding with derived per-token sequence ids. *)
type encoding_data = {
  ids : int array;
  type_ids : int array;
  tokens : string array;
  offsets : (int * int) array;
  attention_mask : int array;
  special_tokens_mask : int array;
  overflowing : t list;
  word_ids : int option array;
  sequence_ids : int option array;
  n_sequences : int;
}

(** Encoding data for serialization - currently unused but may be needed
    for JSON serialization *)
let _to_data (t : t) : encoding_data =
  {
    ids = t.ids;
    type_ids = t.type_ids;
    tokens = t.tokens;
    offsets = t.offsets;
    attention_mask = t.attention_mask;
    special_tokens_mask = t.special_tokens_mask;
    overflowing = t.overflowing;
    word_ids = t.words;
    sequence_ids = get_sequence_ids t;
    n_sequences = n_sequences t;
  }

(** Rebuild an encoding from [encoding_data], reconstructing sequence
    ranges from the per-token [sequence_ids] when more than one sequence
    is present. *)
let _from_data (d : encoding_data) : t =
  let t =
    {
      ids = d.ids;
      type_ids = d.type_ids;
      tokens = d.tokens;
      words = d.word_ids;
      offsets = d.offsets;
      special_tokens_mask = d.special_tokens_mask;
      attention_mask = d.attention_mask;
      overflowing = d.overflowing;
      sequence_ranges = Hashtbl.create 1;
    }
  in
  (* Reconstruct sequence ranges from sequence_ids if needed *)
  if d.n_sequences > 1 then (
    let current_seq = ref None in
    let start = ref 0 in
    Array.iteri
      (fun i seq_opt ->
        match (seq_opt, !current_seq) with
        | Some seq, None ->
            current_seq := Some seq;
            start := i
        | Some seq, Some curr_seq when seq <> curr_seq ->
            (* Sequence id changed: close the previous run. *)
            Hashtbl.add t.sequence_ranges curr_seq (!start, i);
            current_seq := Some seq;
            start := i
        | None, Some curr_seq ->
            Hashtbl.add t.sequence_ranges curr_seq (!start, i);
            current_seq := None
        | _ -> ())
      d.sequence_ids;
    (* Handle the last sequence *)
    match !current_seq with
    | Some seq ->
        Hashtbl.add t.sequence_ranges seq (!start, Array.length d.sequence_ids)
    | None -> ());
  t
sectionYPositions = computeSectionYPositions($el), 10)"
x-init="setTimeout(() => sectionYPositions = computeSectionYPositions($el), 10)"
>