package neural_nets_lib

  1. Overview
  2. Docs

Source file operation.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
(** Computational primitives for neural networks, integrating [Tensor] with [Assignments]. *)

open Base
module Asgns = Arrayjit.Assignments
module Idx = Arrayjit.Indexing

(** Indexing operators for reading and writing tensor values and gradients by explicit indices. *)
module At = struct
  (** Get the value at the given indices. *)
  let ( .@{} ) = Tensor.get_value

  (** Get the gradient at the given indices. *)
  let ( .@%{} ) = Tensor.get_grad

  (** Set the value at the given indices. *)
  let ( .@{}<- ) = Tensor.set_value

  (** Set the gradient at the given indices. *)
  let ( .@%{}<- ) = Tensor.set_grad

  (** Get the value at the given index from a single-axis shape tensor. *)
  let ( .@[] ) t indx = Tensor.get_value t [| indx |]

  (** Get the gradient at the given index from a single-axis shape tensor. *)
  let ( .@%[] ) t indx = Tensor.get_grad t [| indx |]

  (** Set the value at the given index for a single-axis shape tensor. *)
  let ( .@[]<- ) t indx = Tensor.set_value t [| indx |]

  (** Set the gradient at the given index for a single-axis shape tensor. *)
  let ( .@%[]<- ) t indx = Tensor.set_grad t [| indx |]
end

(** Seed of the non-differentiable DSL: tensor constructors with gradients prohibited. The
    operator module [O] is intentionally left empty here; it is populated later (see [NDO] and
    the full [NTDSL]). *)
module Initial_NTDSL = struct
  let term = Tensor.term ~grad_spec:Prohibit_grad
  let number = Tensor.number ~grad_spec:Prohibit_grad
  let ndarray = Tensor.ndarray ~grad_spec:Prohibit_grad

  module O = struct end
end

(** Seed of the differentiable DSL: tensor constructors with gradients computed if needed. The
    operator module [O] is intentionally left empty here; it is populated later (see [DO] and
    the full [TDSL]). *)
module Initial_TDSL = struct
  let term = Tensor.term ~grad_spec:If_needed
  let number = Tensor.number ~grad_spec:If_needed
  let ndarray = Tensor.ndarray ~grad_spec:If_needed
  let param = Tensor.param

  module O = struct end
end

(** Pointwise addition. Forward: [v =: v1 + v2]; backward: the output gradient [g] is
    accumulated into both argument gradients. *)
let add ?(label = []) =
  let module NTDSL = Initial_NTDSL in
  let%cd op_asn ~v ~t1 ~t2 ~projections = v =: v1 + v2 in
  (* d(t1+t2)/dt1 = d(t1+t2)/dt2 = 1, so both gradients just accumulate [g]. *)
  let%cd grad_asn ~v:_ ~g ~t1 ~t2 ~projections =
    g1 =+ g;
    g2 =+ g
  in
  Tensor.binop ~label:("+" :: label) ~compose_op:Pointwise_bin ~op_asn ~grad_asn

(** Pointwise subtraction. Forward: [v =: v1 - v2]; backward: [g] is added into [t1]'s gradient
    and subtracted from [t2]'s gradient. *)
let sub ?(label = []) =
  let module NTDSL = Initial_NTDSL in
  let%cd op_asn ~v ~t1 ~t2 ~projections = v =: v1 - v2 in
  (* d(t1-t2)/dt1 = 1 but d(t1-t2)/dt2 = -1, hence [=-] for the second gradient. *)
  let%cd grad_asn ~v:_ ~g ~t1 ~t2 ~projections =
    g1 =+ g;
    g2 =- g
  in
  Tensor.binop ~label:("-" :: label) ~compose_op:Pointwise_bin ~op_asn ~grad_asn

(* Shared helper for the multiplication variants ([pointmul], [matmul]): the caller supplies
   the composition spec and the forward assignment, while the product-rule gradient
   (g1 += g * v2, g2 += v1 * g) is common to both. *)
let mul compose_op ~op_asn =
  let module NTDSL = Initial_NTDSL in
  let%cd grad_asn ~v:_ ~g ~t1 ~t2 ~projections =
    g1 =+ g * v2;
    g2 =+ v1 * g
  in
  Tensor.binop ~compose_op ~op_asn ~grad_asn

(** Pointwise (Hadamard) multiplication: [v =: v1 * v2] under pointwise projections. *)
let pointmul ?(label = []) =
  let module NTDSL = Initial_NTDSL in
  let%cd op_asn ~v ~t1 ~t2 ~projections = v =: v1 * v2 in
  mul Pointwise_bin ~op_asn ~label:("*." :: label)

(* N1: AxB, N2: BxC, v: AxC; A is the output of N1, B is the input/output of N1/N2, C is the
   input of N2. Standard matrix algebra would require inserting transposes into the gradient
   multiplications: AxB = AxC * CxB = AxC * (BxC)^T -> N1g += Ng * N2v^T, and
   BxC = BxA * AxC = (AxB)^T * AxC -> N2g += N1v^T * Ng. In our setup there is no transposing
   to do, since the projections produce the correct indices for their corresponding matrices. *)

(** Matrix multiplication via composition: [v =:+ v1 * v2] accumulates products over the
    contracted axis (see the transposition note above). *)
let matmul ?(label = []) =
  let module NTDSL = Initial_NTDSL in
  let%cd op_asn ~v ~t1 ~t2 ~projections = v =:+ v1 * v2 in
  mul Compose ~op_asn ~label:("*" :: label)

(** Similar to the explicit mode of [numpy.einsum], the binary variant. Can compute various forms of
    matrix multiplication, inner and outer products, etc.

    Note that ["a,b->c"] from [numpy] is ["a;b=>c"] in OCANNL, since ["->"] is used to separate the
    input and the output axes. *)
let einsum ?(label = []) spec =
  let module NTDSL = Initial_NTDSL in
  (* Forward accumulates products over the summed-out axes, as directed by [spec]. *)
  let%cd op_asn ~v ~t1 ~t2 ~projections = v =:+ v1 * v2 in
  (* Product-rule gradients; the einsum projections supply the appropriate indexing. *)
  let%cd grad_asn ~v:_ ~g ~t1 ~t2 ~projections =
    g1 =+ g * v2;
    g2 =+ v1 * g
  in
  Tensor.binop ~label:(";=>" :: label) ~compose_op:(Einsum spec) ~op_asn ~grad_asn

(** Like [einsum], but adds instead than multiplying the resulting values. *)
let outer_sum ?(label = []) spec =
  let module NTDSL = Initial_NTDSL in
  (* Forward accumulates sums ([v1 + v2]) rather than products. *)
  let%cd op_asn ~v ~t1 ~t2 ~projections = v =:+ v1 + v2 in
  (* Addition is linear in both arguments, so both gradients just accumulate [g]. *)
  let%cd grad_asn ~v:_ ~g ~t1 ~t2 ~projections =
    g1 =+ g;
    g2 =+ g
  in
  Tensor.binop ~label:(";=>+" :: label) ~compose_op:(Einsum spec) ~op_asn ~grad_asn

(** Similar to the explicit mode of [numpy.einsum], the unary variant. Can permute axes, extract
    diagonals, compute traces etc.

    Note that ["a->c"] from [numpy] is ["a=>c"] in OCANNL, since ["->"] is used to separate the
    input and the output axes. *)
let einsum1 ?(label = []) spec =
  let module NTDSL = Initial_NTDSL in
  (* Forward accumulates [v1] over the axes summed out by [spec]. *)
  let%cd op_asn ~v ~t1 ~projections = v =:+ v1 in
  (* The operation is linear, so the gradient is [g] routed back through the permutation. *)
  let%cd grad_asn ~v:_ ~g ~t1 ~projections = g1 =+ g in
  Tensor.unop ~label:("=>" :: label) ~transpose_op:(Shape.Permute spec) ~op_asn ~grad_asn

(** Rectified linear unit, written [?/] in the DSLs. Forward: [v =: ?/v1]; backward uses the
    [-?/] primitive which gates [g] by the forward value [v]. *)
let relu ?(label = []) =
  let module NTDSL = Initial_NTDSL in
  let%cd op_asn ~v ~t1 ~projections = v =: ?/v1 ~projections in
  let%cd grad_asn ~v ~g ~t1 ~projections = g1 =+ v -?/ g in
  Tensor.unop ~label:("?/" :: label) ~transpose_op:Pointwise_un ~op_asn ~grad_asn

(** Non-differentiable operators, minus [**.] — used to bootstrap [pointpow], which needs these
    operators in scope but itself defines [**.]. *)
module NDO_without_pow = struct
  let ( * ) = matmul ~grad_spec:Prohibit_grad
  let ( *. ) = pointmul ~grad_spec:Prohibit_grad
  let ( + ) = add ~grad_spec:Prohibit_grad
  let ( ?/ ) = relu ~grad_spec:Prohibit_grad
  let ( !. ) = Tensor.number ~grad_spec:Prohibit_grad

  (* Integer-literal convenience: converts to float before building the number tensor. *)
  let ( !.. ) ?label i = Tensor.number ?label ~grad_spec:Prohibit_grad @@ Float.of_int i
  let ( - ) = sub ~grad_spec:Prohibit_grad

  (* Unary negation expressed as pointwise multiplication by -1. *)
  let ( ~- ) ?label t = ( *. ) ?label !.(-1.) t
end

(** [pointpow p t1] raises [t1] to the power [p], pointwise. The gradient is specialized for
    [p = 1.0] and [p = 2.0]; the general case recurses through [**.] with exponent [p -. 1.]. *)
let rec pointpow ?(label : string list = []) ~grad_spec p t1 : Tensor.t =
  let module NTDSL = struct
    include Initial_NTDSL

    module O = struct
      include NDO_without_pow

      (* [**.] needed by the [%cd] bodies below; inner powers never track gradients. Note the
         argument swap: [base **. exp] forwards to [pointpow exp base]. *)
      let ( **. ) ?label base exp = pointpow ?label ~grad_spec:Tensor.Prohibit_grad exp base
    end
  end in
  (* The exponent embedded as a constant scalar tensor, passed as the second operand. *)
  let p_t = NTDSL.number p in
  let%cd op_asn ~v ~t1 ~t2 ~projections = v =: v1 ** v2 ~projections in
  (* d(t1^p)/dt1 = p * t1^(p-1); skipped entirely when gradients are prohibited. *)
  let%cd grad_asn =
    if Tensor.is_prohibit_grad grad_spec then fun ~v:_ ~g:_ ~t1:_ ~t2:_ ~projections:_ -> Asgns.Noop
    else if Float.equal p 2.0 then fun ~v:_ ~g ~t1 ~t2:_ ~projections -> g1 =+ p_t *. t1 * g
    else if Float.equal p 1.0 then fun ~v:_ ~g ~t1 ~t2:_ ~projections -> g1 =+ g
    else fun ~v:_ ~g ~t1 ~t2:_ ~projections -> g1 =+ p_t *. (t1 **. (p -. 1.)) * g
  in
  Tensor.binop ~label:("**." :: label) ~compose_op:Pointwise_bin ~op_asn ~grad_asn ~grad_spec t1 p_t

(** Non-differentiable operators, minus [/.] — used to bootstrap [pointdiv], which needs these
    operators in scope but itself defines [/.]. *)
module NDO_without_div = struct
  include NDO_without_pow

  (* Note the argument swap: [base **. exp] forwards to [pointpow exp base]. *)
  let ( **. ) ?label base exp = pointpow ?label ~grad_spec:Tensor.Prohibit_grad exp base
end

(** [pointdiv t1 t2] divides [t1] by [t2], pointwise. Gradients:
    [g1 += g / v2] and [g2 += g * (-t1 / t2^2)], the quotient rule. *)
let rec pointdiv ?(label : string list = []) ~grad_spec t1 t2 =
  let module NTDSL = struct
    include Initial_NTDSL

    module O = struct
      include NDO_without_div

      (* [/.] needed by the [%cd] gradient below; the inner division never tracks gradients. *)
      let ( /. ) = pointdiv ~grad_spec:Tensor.Prohibit_grad
    end
  end in
  let%cd op_asn ~v ~t1 ~t2 ~projections = v =: v1 / v2 in
  (* We cannot use g in a tensor expression since it's an array, so we keep it to the left
     (RHS1). *)
  let%cd grad_asn ~v:_ ~g ~t1 ~t2 ~projections =
    g1 =+ g / v2;
    g2 =+ g * (-1 *. t1 /. (t2 **. 2))
  in
  Tensor.binop ~label:("/." :: label) ~compose_op:Pointwise_bin ~op_asn ~grad_asn ~grad_spec t1 t2

(** [range upto] is a tensor with a single output axis of size [upto + 1], initialized with
    [Range_over_offsets] (values running over the offsets, i.e. [0, 1, ..., upto]); gradients
    prohibited by default. When [axis_label] is given, the axis is labeled accordingly. *)
let range ?(label = []) ?(grad_spec = Tensor.Prohibit_grad) ?axis_label upto =
  let dim = upto + 1 in
  let term =
    Tensor.term
      ~label:(("0" ^ "..." ^ Int.to_string upto) :: label)
      ~grad_spec ~batch_dims:[] ~input_dims:[] ~init_op:Range_over_offsets
  in
  match axis_label with
  | Some l -> term ~output_axes:[ (l, dim) ] ()
  | None -> term ~output_dims:[ dim ] ()

(** [range_of_shape ()] is a tensor of the given shape (specified either via dimension lists or
    via labeled-axis lists), initialized with [Range_over_offsets]; gradients prohibited by
    default. *)
let range_of_shape ?(label = []) ?(grad_spec = Tensor.Prohibit_grad) ?batch_dims ?input_dims
    ?output_dims ?batch_axes ?input_axes ?output_axes () =
  (* Dimensions of one axis kind: prefer the explicit dims list, fall back to the sizes of the
     labeled axes, else empty. *)
  let dims_of explicit axes =
    match (explicit, axes) with
    | Some ds, _ -> Array.of_list ds
    | None, Some axs -> Array.of_list (List.map axs ~f:snd)
    | None, None -> [||]
  in
  (* Concatenated in batch, output, input order — used only for the label. *)
  let dims =
    Array.concat
      [
        dims_of batch_dims batch_axes;
        dims_of output_dims output_axes;
        dims_of input_dims input_axes;
      ]
  in
  (* Default a dims argument to [] only when the matching axes argument is also absent, so the
     axes specification is not overridden. *)
  let fill_default explicit axes =
    match explicit with Some _ -> explicit | None -> if Option.is_none axes then Some [] else None
  in
  let batch_dims = fill_default batch_dims batch_axes in
  let input_dims = fill_default input_dims input_axes in
  let output_dims = fill_default output_dims output_axes in
  Tensor.term
    ~label:(("r" ^ Idx.dims_to_string dims) :: label)
    ~grad_spec ?batch_dims ?input_dims ?output_dims ?batch_axes ?input_axes ?output_axes
    ~init_op:Range_over_offsets ()

(** A [stop_gradient] is an identity in the forward pass and a no-op in the backprop pass. *)
let stop_gradient ?(label = []) =
  let module NTDSL = Initial_NTDSL in
  (* Backprop contributes nothing: gradients do not flow past this node. *)
  let grad_asn ~v:_ ~g:_ ~t1:_ ~projections:_ = Asgns.Noop in
  (* Forward pass is a plain identity copy. *)
  let%cd op_asn ~v ~t1 ~projections = v =: v1 in
  Tensor.unop ~label:("stop_grad" :: label) ~transpose_op:Pointwise_un ~op_asn ~grad_asn
    ~grad_spec:Prohibit_grad

(** [slice batch_idx t1] selects the entry of [t1] at the static batch index [batch_idx]
    (the [Batch_slice] transpose op); written [@|] in the DSLs. The gradient accumulates [g]
    straight back into the sliced positions. *)
let slice ?(label = []) ~grad_spec (batch_idx : Idx.static_symbol) t1 : Tensor.t =
  let module NTDSL = Initial_NTDSL in
  let op_asn ~v ~t1 ~projections =
    (* Forward pass is a fetch: copy the [batch_idx] slice of [t1]'s value array into [v];
       the result dims come lazily from the projections' LHS dims. *)
    Asgns.Fetch
      {
        array = v;
        fetch_op = Slice { batch_idx; sliced = t1.Tensor.value };
        dims = lazy (Lazy.force projections).Idx.lhs_dims;
      }
  in
  let%cd grad_asn ~v:_ ~g ~t1 ~projections = g1 =+ g in
  Tensor.unop ~label:("@|" :: label) ~transpose_op:(Batch_slice batch_idx) ~op_asn ~grad_asn
    ~grad_spec t1

(** [embed_symbol static_sym] is a tensor with a single output cell whose value is fetched from
    the static symbol [static_sym]; written [!@] in [DO]. Gradients are prohibited. *)
let embed_symbol ?(label = []) static_sym : Tensor.t =
  let module NTDSL = Initial_NTDSL in
  let op_asn ~v ~projections =
    (* Result dims come lazily from the projections' LHS dims. *)
    let dims = lazy (Lazy.force projections).Idx.lhs_dims in
    Asgns.Fetch { array = v; fetch_op = Embed_symbol static_sym; dims }
  in
  (* Nothing to backpropagate into a symbol embedding. *)
  let grad_asn ~v:_ ~g:_ ~projections:_ = Asgns.Noop in
  let shape = Shape.make ~batch_dims:[] ~input_dims:[] ~output_dims:[ 1 ] () in
  Tensor.op ~label:("!@" :: label) ~op_asn ~grad_asn ~grad_spec:Prohibit_grad shape []

(** Differentiable operators ([grad_spec:If_needed]); exposed as [TDSL.O]. *)
module DO = struct
  let ( * ) = matmul ~grad_spec:If_needed
  let ( *. ) = pointmul ~grad_spec:If_needed
  let ( + ) = add ~grad_spec:If_needed

  (* Note the argument swap: [base **. exp] forwards to [pointpow exp base]. *)
  let ( **. ) ?label base exp = pointpow ?label exp base ~grad_spec:If_needed
  let ( ?/ ) = relu ~grad_spec:If_needed
  let ( !~ ) label = Tensor.param label
  let ( !. ) = Tensor.number ~grad_spec:If_needed

  (* Integer-literal convenience: converts to float before building the number tensor. *)
  let ( !.. ) ?label i = Tensor.number ?label ~grad_spec:If_needed @@ Float.of_int i
  let ( !@ ) = embed_symbol
  let ( - ) = sub ~grad_spec:If_needed

  (* Unary negation expressed as pointwise multiplication by -1. *)
  let ( ~- ) ?label t = ( *. ) ?label !.(-1.) t
  let ( /. ) = pointdiv ~grad_spec:If_needed
  let ( @| ) ?label t1 idx = slice ?label ~grad_spec:If_needed idx t1
end

(** The complete set of non-differentiable operators; exposed as [NTDSL.O]. *)
module NDO = struct
  include NDO_without_div

  let ( /. ) = pointdiv ~grad_spec:Prohibit_grad
  let ( @| ) ?label t1 idx = slice ?label ~grad_spec:Prohibit_grad idx t1
end

(** The differentiable tensor DSL: constructors from [Initial_TDSL], operators from [DO], and
    the named operations specialized to [grad_spec:If_needed]. *)
module TDSL = struct
  include Initial_TDSL
  module O = DO

  let einsum = einsum ~grad_spec:If_needed
  let outer_sum = outer_sum ~grad_spec:If_needed
  let einsum1 = einsum1 ~grad_spec:If_needed
  let range = range ~grad_spec:If_needed
  let range_of_shape = range_of_shape ~grad_spec:If_needed
  let stop_gradient = stop_gradient

  (** The input [i] dimensions default to empty. The batch dimensions will be inferred if omitted.
      [strict] controls whether [Constant_fill] will try to fit the given values in the tensor and
      contribute to shape inference. If it is not provided explicitly, it will be [true] if [b] is
      omitted, and [false] otherwise. *)
  let init_const ~l ?strict ?b ?(i = []) ~o values =
    let strict =
      match (strict, b) with Some s, _ -> s | None, Some _ -> false | None, None -> true
    in
    Tensor.term ~label:[ l ] ~grad_spec:Prohibit_grad ?batch_dims:b ~input_dims:i ~output_dims:o
      ~init_op:(Constant_fill { values; strict })
      ()

  (** It's like [Tensor.param] but without shape inference: all dimension lists default to empty
      and [Constant_fill] is non-strict. Gradients are required. *)
  let init_param ~l ?(b = []) ?(i = []) ?(o = []) values =
    Tensor.term ~label:[ l ] ~grad_spec:Require_grad ~batch_dims:b ~input_dims:i ~output_dims:o
      ~init_op:(Constant_fill { values; strict = false })
      ()
end

(** The non-differentiable tensor DSL: constructors from [Initial_NTDSL], operators from [NDO],
    and the named operations specialized to [grad_spec:Prohibit_grad]. *)
module NTDSL = struct
  include Initial_NTDSL
  module O = NDO

  let einsum = einsum ~grad_spec:Prohibit_grad
  let outer_sum = outer_sum ~grad_spec:Prohibit_grad
  let einsum1 = einsum1 ~grad_spec:Prohibit_grad
  let term = Tensor.term ~grad_spec:Prohibit_grad
  let range = range ~grad_spec:Prohibit_grad
  let range_of_shape = range_of_shape ~grad_spec:Prohibit_grad

  (** Accumulates [t1] into the result value ([v =+ t1]) on each forward pass — presumably so
      that repeated forward runs act as a counter; no gradient. *)
  let counter ?(label = []) =
    let module NTDSL = Initial_NTDSL in
    let%cd op_asn ~v ~t1 ~projections = v =+ t1 ~projections in
    let grad_asn ~v:_ ~g:_ ~t1:_ ~projections:_ = Asgns.Noop in
    Tensor.unop ~label:("counter" :: label) ~transpose_op:Pointwise_un ~op_asn ~grad_asn
      ~grad_spec:Prohibit_grad
end