package kaun

Source file attention.ml

module Dtype = Nx_core.Dtype
module Ptree = Ptree
module Initializers = Initializers

(* Coerce an attention mask of arbitrary dtype to a boolean mask: boolean
   tensors are cast through directly, anything else is compared against zero
   so that non-zero entries mean "attend". *)
let normalize_mask (type a) (type layout) (mask : (a, layout) Rune.t) :
    Rune.bool_t =
  let dtype = Rune.dtype mask in
  match Dtype.equal_witness dtype Rune.bool with
  | Some Type.Equal -> Rune.cast Rune.bool mask
  | None ->
      let zeros = Rune.zeros_like mask in
      Rune.not_equal mask zeros |> Rune.cast Rune.bool

(* Multi-head attention over already-projected [q]/[k]/[v] tensors of shape
   [batch; seq_len; heads * head_dim]: split the last dimension into heads,
   repeat key/value heads when [num_kv_heads] < [num_heads] (grouped-query
   attention), broadcast the optional mask to
   [batch; num_heads; seq_len_q; seq_len_k], run scaled dot-product attention
   with optional dropout, then merge heads back to
   [batch; seq_len_q; embed_dim]. *)
let compute_attention_from_projected ?attention_mask ?(is_causal = false)
    ?dropout_rate ?dropout_rng ?scale ~q ~k ~v ~embed_dim ~num_heads
    ~num_kv_heads ~head_dim () =
  if embed_dim <> num_heads * head_dim then
    invalid_arg
      (Printf.sprintf
         "multi-head attention: embed_dim (%d) must equal num_heads (%d) * \
          head_dim (%d)"
         embed_dim num_heads head_dim);
  let reshape_heads tensor heads =
    let tensor = Rune.contiguous tensor in
    let shape = Rune.shape tensor in
    if Array.length shape <> 3 then
      invalid_arg "multi-head attention expects projected tensors of rank 3";
    let last_dim = shape.(2) in
    if last_dim <> heads * head_dim then
      invalid_arg
        (Printf.sprintf
           "multi-head attention: projected dimension mismatch (got %d, \
            expected %d)"
           last_dim (heads * head_dim));
    let reshaped =
      Rune.reshape [| shape.(0); shape.(1); heads; head_dim |] tensor
    in
    Rune.transpose reshaped ~axes:[ 0; 2; 1; 3 ]
  in
  let q_heads = reshape_heads q num_heads in
  let k_heads = reshape_heads k num_kv_heads in
  let v_heads = reshape_heads v num_kv_heads in
  let repeat_if_needed tensor =
    if num_kv_heads < num_heads then (
      if num_heads mod num_kv_heads <> 0 then
        invalid_arg
          (Printf.sprintf
             "multi-head attention: num_heads (%d) must be a multiple of \
              num_kv_heads (%d)"
             num_heads num_kv_heads);
      let repeat_factor = num_heads / num_kv_heads in
      let shape = Rune.shape tensor in
      let expanded = Rune.expand_dims [ 2 ] tensor in
      let target =
        [| shape.(0); shape.(1); repeat_factor; shape.(2); shape.(3) |]
      in
      let broadcasted = Rune.broadcast_to target expanded in
      Rune.reshape [| shape.(0); num_heads; shape.(2); shape.(3) |] broadcasted)
    else tensor
  in
  let k_heads = repeat_if_needed k_heads in
  let v_heads = repeat_if_needed v_heads in
  let q_shape = Rune.shape q in
  let batch = q_shape.(0) in
  let seq_len_q = q_shape.(1) in
  let seq_len_k = (Rune.shape k).(1) in
  let attention_mask =
    match attention_mask with
    | None -> None
    | Some mask ->
        let mask = normalize_mask mask in
        let shape = Rune.shape mask in
        let prepared =
          match Array.length shape with
          | 2 ->
              let batch_dim = shape.(0) in
              let key_dim = shape.(1) in
              if
                (batch_dim <> batch && batch_dim <> 1)
                || (key_dim <> seq_len_k && key_dim <> 1)
              then
                invalid_arg
                  "attention mask of rank 2 must align with [batch; seq_len_k]";
              Rune.reshape [| batch_dim; 1; 1; key_dim |] mask
          | 3 ->
              let batch_dim = shape.(0) in
              let query_dim = shape.(1) in
              let key_dim = shape.(2) in
              if
                (batch_dim <> batch && batch_dim <> 1)
                || (query_dim <> seq_len_q && query_dim <> 1)
                || (key_dim <> seq_len_k && key_dim <> 1)
              then
                invalid_arg
                  "attention mask of rank 3 must align with [batch; seq_len_q; \
                   seq_len_k]";
              Rune.expand_dims [ 1 ] mask
          | 4 ->
              let batch_dim = shape.(0) in
              let head_dim = shape.(1) in
              let query_dim = shape.(2) in
              let key_dim = shape.(3) in
              if
                (batch_dim <> batch && batch_dim <> 1)
                || (head_dim <> num_heads && head_dim <> 1)
                || (query_dim <> seq_len_q && query_dim <> 1)
                || (key_dim <> seq_len_k && key_dim <> 1)
              then
                invalid_arg
                  "attention mask of rank 4 must align with [batch; num_heads; \
                   seq_len_q; seq_len_k]";
              mask
          | _ ->
              invalid_arg
                "attention mask rank must be 2, 3, or 4 for multi-head \
                 attention"
        in
        let target = [| batch; num_heads; seq_len_q; seq_len_k |] in
        Some (Rune.broadcast_to target prepared)
  in
  let attn =
    let dropout_seed =
      match dropout_rng with
      | Some rng -> Some (Rune.Rng.to_int rng)
      | None when Option.is_some dropout_rate ->
          invalid_arg "attention dropout requires RNG"
      | None -> None
    in
    Rune.dot_product_attention ?attention_mask ?scale ?dropout_rate
      ?dropout_seed ~is_causal q_heads k_heads v_heads
  in
  attn
  |> Rune.transpose ~axes:[ 0; 2; 1; 3 ]
  |> Rune.contiguous
  |> Rune.reshape [| batch; seq_len_q; embed_dim |]

(* Multi-head attention layer: configuration, parameter initialization and
   application. Supports grouped-query attention via [num_kv_heads], optional
   query/key normalization scales, an optional tanh soft cap on the output
   and an optional query pre-attention scale. *)
module Multi_head = struct
  type config = {
    embed_dim : int;
    num_heads : int;
    num_kv_heads : int option;
    head_dim : int option;
    dropout : float;
    use_qk_norm : bool;
    attn_logits_soft_cap : float option;
    query_pre_attn_scalar : float option;
  }

  let make_config ~embed_dim ~num_heads ?num_kv_heads ?head_dim ?(dropout = 0.0)
      ?(use_qk_norm = false) ?attn_logits_soft_cap ?query_pre_attn_scalar () =
    {
      embed_dim;
      num_heads;
      num_kv_heads;
      head_dim;
      dropout;
      use_qk_norm;
      attn_logits_soft_cap;
      query_pre_attn_scalar;
    }

  type params = Ptree.t

  (* Initialize parameters as a [Ptree.dict]: Glorot-uniform [q_proj],
     [k_proj], [v_proj] and [out_proj] matrices, plus per-head-dim scale
     vectors [q_norm_scale] and [k_norm_scale] (initialized to ones) when
     [use_qk_norm] is set. *)
  let init config ~rngs ~dtype =
    let head_dim =
      Option.value config.head_dim ~default:(config.embed_dim / config.num_heads)
    in
    if head_dim * config.num_heads <> config.embed_dim then
      invalid_arg
        (Printf.sprintf
           "multi-head attention: embed_dim (%d) not divisible by num_heads \
            (%d)"
           config.embed_dim config.num_heads);
    let num_kv_heads =
      Option.value config.num_kv_heads ~default:config.num_heads
    in
    let num_keys = if config.use_qk_norm then 6 else 4 in
    let keys = Rune.Rng.split ~n:num_keys rngs in
    let init_fn = (Initializers.glorot_uniform ()).f in
    let q_proj =
      init_fn
        (Rune.Rng.to_int keys.(0))
        [| config.embed_dim; config.num_heads * head_dim |]
        dtype
    in
    let k_proj =
      init_fn
        (Rune.Rng.to_int keys.(1))
        [| config.embed_dim; num_kv_heads * head_dim |]
        dtype
    in
    let v_proj =
      init_fn
        (Rune.Rng.to_int keys.(2))
        [| config.embed_dim; num_kv_heads * head_dim |]
        dtype
    in
    let out_proj =
      init_fn
        (Rune.Rng.to_int keys.(3))
        [| config.num_heads * head_dim; config.embed_dim |]
        dtype
    in
    let base =
      [
        ("q_proj", Ptree.tensor q_proj);
        ("k_proj", Ptree.tensor k_proj);
        ("v_proj", Ptree.tensor v_proj);
        ("out_proj", Ptree.tensor out_proj);
      ]
    in
    let base =
      if config.use_qk_norm then
        let scale = Rune.ones dtype [| head_dim |] in
        base
        @ [
            ("q_norm_scale", Ptree.tensor scale);
            ("k_norm_scale", Ptree.tensor scale);
          ]
      else base
    in
    Ptree.dict base

  (* Forward pass: project [query], [key] and [value], run
     [compute_attention_from_projected] (dropout is active only when
     [training] is true and [config.dropout] > 0, in which case an RNG is
     required), apply [out_proj], and finally soft-cap the output with
     [tanh] when [attn_logits_soft_cap] is set. *)
  let apply ?rngs ?attention_mask config params ~training ~query ~key ~value =
    let dtype = Rune.dtype query in
    let fields = Ptree.Dict.fields_exn ~ctx:"attention.multi_head" params in
    let get name = Ptree.Dict.get_tensor_exn fields ~name dtype in
    let q_proj = get "q_proj" in
    let k_proj = get "k_proj" in
    let v_proj = get "v_proj" in
    let out_proj = get "out_proj" in
    let head_dim =
      Option.value config.head_dim ~default:(config.embed_dim / config.num_heads)
    in
    let num_kv_heads =
      Option.value config.num_kv_heads ~default:config.num_heads
    in
    let scale : float =
      match config.query_pre_attn_scalar with
      | Some s -> s
      | None -> 1.0 /. Stdlib.sqrt (float_of_int head_dim)
    in
    let effective_dropout = if training then config.dropout else 0.0 in
    let dropout_rng =
      if effective_dropout > 0.0 then
        match rngs with
        | Some key -> Some key
        | None -> failwith "attention dropout requires RNG"
      else None
    in
    let attention_mask = Option.map normalize_mask attention_mask in
    let q_projected = Rune.matmul query q_proj in
    let k_projected = Rune.matmul key k_proj in
    let v_projected = Rune.matmul value v_proj in
    let context =
      compute_attention_from_projected ?attention_mask
        ?dropout_rate:
          (if effective_dropout > 0.0 then Some effective_dropout else None)
        ?dropout_rng ~is_causal:false ~scale ~q:q_projected ~k:k_projected
        ~v:v_projected ~embed_dim:config.embed_dim ~num_heads:config.num_heads
        ~num_kv_heads ~head_dim ()
    in
    let output = Rune.matmul context out_proj in
    match config.attn_logits_soft_cap with
    | None -> output
    | Some cap ->
        let scaled = Rune.div output (Rune.scalar (Rune.dtype output) cap) in
        let capped = Rune.tanh scaled in
        Rune.mul capped (Rune.scalar (Rune.dtype output) cap)
end
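
A minimal usage sketch for this module (illustrative only, not part of
attention.ml): it builds a configuration, initializes parameters and runs
self-attention on a dummy batch. Rune.Rng.key, Rune.float32 and Rune.zeros
are assumed to exist with the signatures shown; they are not defined in this
file.

  let example () =
    let config =
      Multi_head.make_config ~embed_dim:256 ~num_heads:8 ~dropout:0.1 ()
    in
    (* Assumed RNG constructor; check the Rune documentation. *)
    let rngs = Rune.Rng.key 0 in
    let params = Multi_head.init config ~rngs ~dtype:Rune.float32 in
    (* Dummy input: [batch = 2; seq_len = 16; embed_dim = 256]. *)
    let x = Rune.zeros Rune.float32 [| 2; 16; 256 |] in
    (* Self-attention: query, key and value are the same tensor. Dropout is
       inactive because ~training:false, so no RNG is needed here. *)
    Multi_head.apply config params ~training:false ~query:x ~key:x ~value:x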