package owl-base

  1. Overview
  2. Docs
Legend:
Page
Library
Module
Module type
Parameter
Class
Class type
Source

Source file owl_computation_type.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
# 1 "src/base/compute/owl_computation_type.ml"
(*
 * OWL - OCaml Scientific and Engineering Computing
 * Copyright (c) 2016-2020 Liang Wang <liang.wang@cl.cam.ac.uk>
 *)

open Owl_types

(* Functor that builds the types of a computation graph for a given [Device]. *)

module Make (Device : Owl_types_computation_device.Sig) = struct
  (* Module constants and device-dependent types.
   *
   * [Device] is re-exported as a submodule of the functor result. Opening it
   * presumably brings the device's [value] type and its [A] module (used
   * below as [A.arr] and [A.elt]) into scope - confirm against
   * [Owl_types_computation_device.Sig]. *)

  module Device = Device
  open Device

  (* type definitions *)

  (* Evaluation state of a node, stored in the [state] field of [attr] to show
   * whether re-evaluation is needed. Presumably [Invalid] marks a node whose
   * value must be recomputed - TODO confirm with the evaluator. *)
  type state =
    | Valid
    | Invalid

  (* A node of the computation graph whose attribute payload is an [attr]
   * record. The types below are mutually recursive: [attr] refers to [op],
   * [state], [value] and [block], and [block] refers back to [t]. *)
  type t = attr Owl_graph.node

  (* A memory block that can be shared by several nodes. *)
  and block =
    { (* Number of elements stored in the block. This assumes that all the
       * elements have the same size; if element types of different sizes are
       * mixed in the same graph, it should be replaced with a size in bytes. *)
      size : int
    ; (* Id of the block. *)
      block_id : int
    ; (* The node whose memory is currently stored in the block, if any. *)
      mutable active : t option
    ; (* Placeholder for the value held by the block. *)
      mutable memory : value
    ; (* The nodes sharing this memory block. *)
      mutable nodes : t list
    }

  (* Attributes carried by every graph node. *)
  and attr =
    { (* Operation stored in this node. *)
      mutable op : op
    ; (* Whether or not the node can link to other nodes. *)
      mutable freeze : bool
    ; (* Whether other nodes can reuse the memory allocated for this node. *)
      mutable reuse : bool
    ; (* State showing whether re-evaluation is needed. *)
      mutable state : state
    ; (* Shapes of the output values stored in the node; presumably one entry
       * per output, with [None] when that shape is not known yet - TODO
       * confirm. *)
      mutable shape : int array option array
    ; (* Output values of the node. *)
      mutable value : value array
    ; (* Memory blocks storing the node's values; [None] when no block is
       * attached. *)
      mutable block : block array option
    }

  (* Typed wrappers around a graph node: [Arr] presumably tags nodes that
   * evaluate to an ndarray and [Elt] nodes that evaluate to a scalar element
   * - confirm with the code that builds graphs. *)
  and arr = Arr of t

  and elt = Elt of t

  (* Every operation a node can hold. Constructor payloads carry the static
   * parameters of the operation (shapes, axes, padding, strides, closures,
   * ...); their exact meaning is defined by the code that builds and
   * evaluates the graph, not here. The section comments below group
   * constructors by name only.
   *
   * NOTE: constructor order determines the runtime tags, and therefore the
   * result of polymorphic comparison/hashing and the marshalled layout. If
   * [op] values are ever compared or marshalled, append new constructors
   * rather than reordering existing ones. *)
  and op =
    (* no-op, variables and constants *)
    | Noop
    | Var
    | Const
    (* creation; the [int array] payload is presumably the result shape *)
    | Empty                         of int array
    | Zeros                         of int array
    | Ones                          of int array
    | Create                        of int array
    | Sequential                    of int array
    | Uniform                       of int array
    | Gaussian                      of int array
    | Bernoulli                     of int array
    | Init                          of int array * (int -> elt)
    (* element and slice access *)
    | Get                           of int array
    | Set                           of int array
    | GetSlice                      of int list list
    | SetSlice                      of int list list
    (* copying, resetting and structural manipulation *)
    | Copy
    | Reset
    | Reshape                       of int array
    | Reverse
    | Tile                          of int array
    | Repeat                        of int array
    | Pad                           of elt * int list list
    | Concatenate                   of int
    | Stack                         of int
    | Split                         of int * int array
    | Draw                          of int * int
    (* operations parameterised by OCaml closures over [elt] *)
    | Map                           of (elt -> elt)
    | Fold                          of int * (elt -> elt -> elt)
    | Scan                          of int * (elt -> elt -> elt)
    (* one-hot encoding and array construction *)
    | OneHot                        of int
    | OfArray                       of int array
    (* arbitrary user functions on device arrays of type [A.arr] *)
    | Delay                         of (A.arr -> A.arr)
    | DelayArray                    of int array * (A.arr array -> A.arr)
    (* printing; the optional fields are presumably printer settings and the
     * closure an element formatter - confirm *)
    | LazyPrint                     of
        int option * int option * bool option * (A.elt -> string) option
    (* unary element-wise maths *)
    | Abs
    | Neg
    | Floor
    | Ceil
    | Round
    | Sqr
    | Sqrt
    | Log
    | Log2
    | Log10
    | Exp
    | Sin
    | Cos
    | Tan
    | Sinh
    | Cosh
    | Tanh
    | Asin
    | Acos
    | Atan
    | Asinh
    | Acosh
    | Atanh
    (* reductions along an axis; the [bool * int] payload is presumably
     * (keep_dims, axis) and the [SumReduce] payload a set of axes - confirm *)
    | Min                           of bool * int
    | Max                           of bool * int
    | Sum                           of bool * int
    | SumReduce                     of int array
    (* more unary element-wise maths *)
    | Signum
    | Sigmoid
    | Relu
    | Dawsn
    (* reductions; a trailing prime presumably marks the variant over all
     * elements *)
    | Min'
    | Max'
    | Sum'
    | LogSumExp'
    | LogSumExp                     of bool * int
    | L1norm'
    | L2norm'
    | L2NormSqr'
    (* clipping *)
    | ClipByValue
    | ClipByL2norm
    (* binary maths; the XScalar / ScalarX forms presumably take a scalar as
     * the second / first operand respectively *)
    | Pow
    | ScalarPow
    | PowScalar
    | Atan2
    | ScalarAtan2
    | Atan2Scalar
    | Hypot
    | Min2
    | Max2
    | Add
    | Sub
    | Mul
    | Div
    | AddScalar
    | SubScalar
    | MulScalar
    | DivScalar
    | ScalarAdd
    | ScalarSub
    | ScalarMul
    | ScalarDiv
    | FMA
    (* element-wise comparisons *)
    | EltEqual
    | EltNotEqual
    | EltLess
    | EltGreater
    | EltLessEqual
    | EltGreaterEqual
    | EltEqualScalar
    | EltNotEqualScalar
    | EltLessScalar
    | EltGreaterScalar
    | EltLessEqualScalar
    | EltGreaterEqualScalar
    (* convolution, pooling and up-sampling *)
    | Conv1d                        of padding * int array
    | Conv2d                        of padding * int array
    | Conv3d                        of padding * int array
    | TransposeConv1d               of padding * int array
    | TransposeConv2d               of padding * int array
    | TransposeConv3d               of padding * int array
    | DilatedConv1d                 of padding * int array * int array
    | DilatedConv2d                 of padding * int array * int array
    | DilatedConv3d                 of padding * int array * int array
    | MaxPool1d                     of padding * int array * int array
    | MaxPool2d                     of padding * int array * int array
    | MaxPool3d                     of padding * int array * int array
    | AvgPool1d                     of padding * int array * int array
    | AvgPool2d                     of padding * int array * int array
    | AvgPool3d                     of padding * int array * int array
    | UpSampling2d                  of int array
    (* backward passes of the convolution, pooling and up-sampling
     * operations above *)
    | Conv1dBackwardInput           of int array
    | Conv1dBackwardKernel          of int array
    | Conv2dBackwardInput           of int array
    | Conv2dBackwardKernel          of int array
    | Conv3dBackwardInput           of int array
    | Conv3dBackwardKernel          of int array
    | TransposeConv1dBackwardInput  of int array
    | TransposeConv1dBackwardKernel of int array
    | TransposeConv2dBackwardInput  of int array
    | TransposeConv2dBackwardKernel of int array
    | TransposeConv3dBackwardInput  of int array
    | TransposeConv3dBackwardKernel of int array
    | DilatedConv1dBackwardInput    of int array * int array
    | DilatedConv1dBackwardKernel   of int array * int array
    | DilatedConv2dBackwardInput    of int array * int array
    | DilatedConv2dBackwardKernel   of int array * int array
    | DilatedConv3dBackwardInput    of int array * int array
    | DilatedConv3dBackwardKernel   of int array * int array
    | MaxPool1dBackward             of padding * int array * int array
    | MaxPool2dBackward             of padding * int array * int array
    | MaxPool3dBackward             of padding * int array * int array
    | AvgPool1dBackward             of padding * int array * int array
    | AvgPool2dBackward             of padding * int array * int array
    | AvgPool3dBackward             of padding * int array * int array
    | UpSampling2dBackward          of int array
    (* matrix and row/column operations *)
    | RowNum
    | ColNum
    | Row
    | Rows                          of int array
    | CopyRowTo
    | CopyColTo
    | Dot                           of bool * bool * elt * elt
    | Inv
    | Trace
    | Transpose                     of int array
    | ToRows
    | OfRows
    (* scalar maths, identified by the Scalar_ prefix *)
    | Scalar_Add
    | Scalar_Sub
    | Scalar_Mul
    | Scalar_Div
    | Scalar_Pow
    | Scalar_Atan2
    | Scalar_Abs
    | Scalar_Neg
    | Scalar_Sqr
    | Scalar_Sqrt
    | Scalar_Exp
    | Scalar_Log
    | Scalar_Log2
    | Scalar_Log10
    | Scalar_Signum
    | Scalar_Floor
    | Scalar_Ceil
    | Scalar_Round
    | Scalar_Sin
    | Scalar_Cos
    | Scalar_Tan
    | Scalar_Sinh
    | Scalar_Cosh
    | Scalar_Tanh
    | Scalar_Asin
    | Scalar_Acos
    | Scalar_Atan
    | Scalar_Asinh
    | Scalar_Acosh
    | Scalar_Atanh
    | Scalar_Relu
    | Scalar_Dawsn
    | Scalar_Sigmoid
    (* fused operations; the meaning of the two floats is not visible here *)
    | Fused_Adagrad                 of float * float
end

(* End of the [Make] functor. *)