package saga

  1. Overview
  2. Docs
Text processing and NLP extensions for Nx

Install

dune-project
 Dependency

Authors

Maintainers

Sources

raven-1.0.0.alpha2.tbz
sha256=93abc49d075a1754442ccf495645bc4fdc83e4c66391ec8aca8fa15d2b4f44d2
sha512=5eb958c51f30ae46abded4c96f48d1825f79c7ce03f975f9a6237cdfed0d62c0b4a0774296694def391573d849d1f869919c49008acffca95946b818ad325f6f

doc/src/saga.tokenizers/unigram.ml.html

Source file unigram.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
(** Simple Unigram tokenizer implementation. *)

(* A vocabulary entry pairs a token string with a probability weight
   (a plain relative frequency as produced by [train]). *)
type vocab_entry = string * float

(* Reverse index: token string -> integer token id. *)
type token_map = (string, int) Hashtbl.t

(* Vocabulary in id order: the array index of an entry is its token id. *)
type vocab = vocab_entry array

(* Tokenizer model: the vocabulary plus its reverse lookup table.
   Invariant (established by [create]): [token_to_ids] maps each token
   in [vocab] to its index. *)
type t = { vocab : vocab; token_to_ids : token_map }

(** [create entries] builds a model from [entries], a
    [(token, probability)] list. List order fixes the token ids, and a
    reverse table maps each token back to its id. If a token appears
    more than once, the last occurrence wins in the reverse table. *)
let create vocab_list =
  let vocab = Array.of_list vocab_list in
  let table = Hashtbl.create (Array.length vocab) in
  vocab |> Array.iteri (fun id (tok, _prob) -> Hashtbl.replace table tok id);
  { vocab; token_to_ids = table }

(** [token_to_id model token] is [Some id] when [token] is in the
    vocabulary, [None] otherwise. *)
let token_to_id model token =
  match Hashtbl.find model.token_to_ids token with
  | id -> Some id
  | exception Not_found -> None

(** [id_to_token model id] is the token string with id [id], or [None]
    when [id] is out of range. *)
let id_to_token model id =
  if id < 0 || id >= Array.length model.vocab then None
  else Some (fst model.vocab.(id))

(** [get_vocab model] lists every [(token, probability)] entry in id order. *)
let get_vocab { vocab; _ } = Array.to_list vocab

(** [get_vocab_size model] is the number of vocabulary entries. *)
let get_vocab_size { vocab; _ } = Array.length vocab

(** [tokenize model text] performs greedy longest-match tokenization.
    Whitespace bytes (space, tab, CR, LF) are skipped; at every other
    position the longest vocabulary match is emitted as
    [(id, token, (start, stop))] with byte offsets. A byte with no
    vocabulary match at all is emitted as a single-byte token with
    id 0 as the unknown fallback. *)
let tokenize model text =
  let len = String.length text in
  let is_blank c = c = ' ' || c = '\n' || c = '\t' || c = '\r' in
  (* Longest vocabulary piece starting at [pos], scanning from the
     whole remaining suffix down to a single byte. *)
  let longest_match pos =
    let rec scan l =
      if l = 0 then None
      else
        let piece = String.sub text pos l in
        match token_to_id model piece with
        | Some id -> Some (id, piece, (pos, pos + l))
        | None -> scan (l - 1)
    in
    scan (len - pos)
  in
  let rec loop pos acc =
    if pos >= len then List.rev acc
    else if is_blank text.[pos] then loop (pos + 1) acc
    else
      match longest_match pos with
      | Some ((_, _, (_, stop)) as tok) -> loop stop (tok :: acc)
      | None ->
          (* No match even at length 1: emit the raw byte with the
             unknown-token fallback id. *)
          let piece = String.sub text pos 1 in
          let id = Option.value ~default:0 (token_to_id model piece) in
          loop (pos + 1) ((id, piece, (pos, pos + 1)) :: acc)
  in
  loop 0 []

(** [save model ~folder ()] writes the model as JSON to
    [folder]/unigram.json and returns the list of file names created. *)
let save model ~folder () =
  let entry id (token, prob) =
    `Assoc
      [ ("id", `Int id); ("token", `String token); ("prob", `Float prob) ]
  in
  let vocab_json = List.mapi entry (Array.to_list model.vocab) in
  let json =
    `Assoc [ ("type", `String "Unigram"); ("vocab", `List vocab_json) ]
  in
  Yojson.Basic.to_file (Filename.concat folder "unigram.json") json;
  [ "unigram.json" ]

(* [train ~vocab_size ... texts existing] builds a model from raw [texts].
   Despite the Unigram-style signature, this is a plain frequency
   trainer: it whitespace-splits every line into words and keeps the
   [vocab_size] most frequent ones. [show_progress], [shrinking_factor],
   [unk_token], [max_piece_length], [n_sub_iterations] and [existing]
   are accepted for interface compatibility but ignored. Returns the
   model together with the (unchanged) [special_tokens] list. *)
let train ~vocab_size ~show_progress ~special_tokens ~shrinking_factor
    ~unk_token ~max_piece_length ~n_sub_iterations texts existing =
  (* Silence unused-variable warnings for the ignored parameters. *)
  let _ =
    ( show_progress,
      shrinking_factor,
      unk_token,
      max_piece_length,
      n_sub_iterations,
      existing )
  in
  (* Count word occurrences across all input lines. *)
  let counts = Hashtbl.create 10000 in
  List.iter
    (fun line ->
      let words = Str.split (Str.regexp "[ \t\n\r]+") line in
      List.iter
        (fun word ->
          if word <> "" then
            Hashtbl.replace counts word
              (1 + Option.value ~default:0 (Hashtbl.find_opt counts word)))
        words)
    texts;

  (* Total word occurrences, used to turn counts into probabilities. *)
  let total =
    Hashtbl.fold (fun _ count acc -> acc + count) counts 0 |> float_of_int
  in
  (* Words sorted by descending frequency. *)
  let sorted =
    Hashtbl.fold (fun token count acc -> (token, count) :: acc) counts []
    |> List.sort (fun (_, c1) (_, c2) -> compare c2 c1)
  in

  (* First [n] elements of [lst] (fewer if the list is shorter). *)
  let take_first n lst =
    let rec aux i = function
      | [] -> []
      | _ when i = 0 -> []
      | x :: xs -> x :: aux (i - 1) xs
    in
    aux n lst
  in

  let selected = take_first vocab_size sorted in
  (* Special tokens are prepended, so they receive the lowest ids; the
     final vocabulary may therefore hold up to
     [vocab_size + List.length special_tokens] entries. Specials get a
     fixed weight of 1/(vocab_size + 1); regular tokens get their
     relative frequency (0 when the corpus is empty). *)
  let vocab_with_probs =
    special_tokens
    |> List.map (fun token -> (token, 1.0 /. float_of_int (vocab_size + 1)))
    |> fun specials ->
    specials
    @ List.map
        (fun (token, count) ->
          let prob = if total = 0. then 0. else float_of_int count /. total in
          (token, prob))
        selected
  in
  let model = create vocab_with_probs in
  (model, special_tokens)