package arrayjit

  1. Overview
  2. Docs

Source file ops.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
open Base
(** Operation types shared by all backends; and precision types. *)

module Lazy = Utils.Lazy

(** {2 *** Precision ***} *)

(** Bigarray element type backing the [Byte] precision (unsigned 8-bit integers). *)
type uint8_elt = Bigarray.int8_unsigned_elt

(* FIXME: Upcoming in OCaml 5.2.0. See:
   https://github.com/ocaml/ocaml/pull/10775/commits/ba6a2c378056c8669fb1bb99bf07b12d69bd4a12 *)
(* For now half precision reuses the float32 element type, so [float16_elt] and [float32_elt]
   are the same type; presumably this changes once the upstream float16 kind is available. *)
type float16_elt = Bigarray.float32_elt
type float32_elt = Bigarray.float32_elt
type float64_elt = Bigarray.float64_elt

(** Bigarray kind used for [Half] precision: currently the float32 kind (4-byte storage), see the
    FIXME above. *)
let float16 : (float, float16_elt) Bigarray.kind = Bigarray.float32

(** Typed precision witness. The first parameter is the OCaml-side element type ([char] for
    bytes, [float] otherwise) and the second is the Bigarray element type. Because
    [float16_elt = float32_elt], [Half] and [Single] have the same type. *)
type ('ocaml, 'impl) precision =
  | Byte : (char, uint8_elt) precision
  | Half : (float, float16_elt) precision
  | Single : (float, float32_elt) precision
  | Double : (float, float64_elt) precision
[@@deriving sexp_of]

(** A precision with its type parameters hidden. Each non-void constructor packs a [precision]
    witness; [Void_prec] has no element type (C type [void], size 0). *)
type prec =
  | Void_prec
  | Byte_prec of (char, uint8_elt) precision
  | Half_prec of (float, float16_elt) precision
  | Single_prec of (float, float32_elt) precision
  | Double_prec of (float, float64_elt) precision

(** Canonical packed values for each supported element precision. *)
let byte, half, single, double =
  (Byte_prec Byte, Half_prec Half, Single_prec Single, Double_prec Double)

(** Serializes a packed precision as the bare constructor name (the witness is not printed);
    the inverse of [prec_of_sexp]. *)
let sexp_of_prec (p : prec) : Sexp.t =
  let constructor_name =
    match p with
    | Double_prec _ -> "Double_prec"
    | Single_prec _ -> "Single_prec"
    | Half_prec _ -> "Half_prec"
    | Byte_prec _ -> "Byte_prec"
    | Void_prec -> "Void_prec"
  in
  Sexp.Atom constructor_name

(** Parses a packed precision from the atom produced by [sexp_of_prec].
    @raise Invalid_argument on a list or an unrecognized atom. *)
let prec_of_sexp (sexp : Sexp.t) : prec =
  match sexp with
  | Sexp.List _ -> invalid_arg "prec_of_sexp: expected atom, found list"
  | Sexp.Atom name -> (
      match name with
      | "Double_prec" -> double
      | "Single_prec" -> single
      | "Half_prec" -> half
      | "Byte_prec" -> byte
      | "Void_prec" -> Void_prec
      | unknown -> invalid_arg ("prec_of_sexp: unknown precision " ^ unknown))

(** Lowercase name of a typed precision witness. *)
let precision_to_string : type ocaml elt_t. (ocaml, elt_t) precision -> string = function
  | Byte -> "byte"
  | Half -> "half"
  | Single -> "single"
  | Double -> "double"

(** Lowercase name of a packed precision, chosen by the outer constructor (not by the packed
    witness). *)
let prec_string (p : prec) : string =
  match p with
  | Double_prec _ -> "double"
  | Single_prec _ -> "single"
  | Half_prec _ -> "half"
  | Byte_prec _ -> "byte"
  | Void_prec -> "void"

(** Equality on packed precisions that compares only the outer constructor; the packed
    witnesses are ignored. *)
let equal_prec p1 p2 =
  (* Exhaustive tag function so a new constructor forces this to be revisited. *)
  let tag = function
    | Void_prec -> 0
    | Byte_prec _ -> 1
    | Half_prec _ -> 2
    | Single_prec _ -> 3
    | Double_prec _ -> 4
  in
  Int.equal (tag p1) (tag p2)

(** Size in bytes of one element at the given precision; [Void_prec] has size 0.
    [Byte_prec] elements are [int8_unsigned_elt] / C [unsigned char], i.e. 1 byte.
    NOTE(review): [Half_prec] reports the nominal 2 bytes, although [float16_elt] is currently
    backed by the float32 Bigarray kind (4-byte storage); TODO confirm which size callers expect. *)
let prec_in_bytes = function
  | Void_prec -> 0
  | Byte_prec _ -> 1
  | Half_prec _ -> 2
  | Single_prec _ -> 4
  | Double_prec _ -> 8

(** Returns the wider of two precisions, ordered Void < Byte < Half < Single < Double.
    On a tie the first argument is returned. *)
let promote_prec p1 p2 =
  let rank = function
    | Void_prec -> 0
    | Byte_prec _ -> 1
    | Half_prec _ -> 2
    | Single_prec _ -> 3
    | Double_prec _ -> 4
  in
  if rank p2 > rank p1 then p2 else p1

(** [true] exactly for the [Double] witness. *)
let is_double : type ocaml elt_t. (ocaml, elt_t) precision -> bool = function
  | Double -> true
  | Byte -> false
  | Half -> false
  | Single -> false

(** [true] exactly for [Double_prec]. *)
let is_double_prec prec =
  match prec with
  | Double_prec _ -> true
  | Void_prec | Byte_prec _ | Half_prec _ | Single_prec _ -> false

(** Hides the type parameters of a precision witness, returning the shared packed constant. *)
let pack_prec : type ocaml elt_t. (ocaml, elt_t) precision -> prec = function
  | Double -> double
  | Single -> single
  | Half -> half
  | Byte -> byte

(** A function polymorphic over all precision witnesses, wrapped in a record with a
    polymorphic field so it can be passed as a first-class value (used by [map_prec]). *)
type 'r map_prec = { f : 'ocaml 'elt_t. ('ocaml, 'elt_t) precision -> 'r }

(** [map_prec ?default { f } p] applies [f] to the precision packed in [p].
    A [Half_prec] is always passed to [f] as [Half], and a [Single_prec] as [Single], whichever
    of the two type-compatible witnesses it actually carries.
    For [Void_prec], returns [default] if given.
    @raise Invalid_argument on [Void_prec] when no [default] is given. *)
let map_prec ?default { f } prec =
  match prec with
  | Double_prec Double -> f Double
  | Single_prec (Single | Half) -> f Single
  | Half_prec (Half | Single) -> f Half
  | Byte_prec Byte -> f Byte
  | Void_prec -> (
      match default with Some result -> result | None -> invalid_arg "map_prec: Void_prec")
  | _ -> .

(** Maps a packed precision to the libgccjit C type kind used for its elements. *)
let c_typ_of_prec prec =
  match prec with
  | Void_prec -> Gccjit.Type.Void
  | Byte_prec _ -> Gccjit.Type.Unsigned_char
  (* FIXME: half precision has no dedicated C type here yet; it is emitted as [float]. *)
  | Half_prec _ -> Gccjit.Type.Float
  | Single_prec _ -> Gccjit.Type.Float
  | Double_prec _ -> Gccjit.Type.Double

(** {2 *** Operations ***} *)

(** Initializes or resets an array by filling in the corresponding numbers, at the appropriate precision. *)
type init_op =
  | Constant_fill of { values : float array; strict : bool }
      (** Fills in the numbers where the rightmost axis is contiguous. If [strict=true], loops over the
          provided values. *)
  | Range_over_offsets
      (** Fills in the offset number of each cell (i.e. how many cells away it is from the beginning). *)
  | Standard_uniform  (** Draws the values from U(0,1). *)
  | File_mapped of string * prec
      (** Reads the data using [Unix.openfile] and [Unix.map_file]. Presumably the [string] is the
          file path and [prec] the precision of the stored data — TODO confirm. *)
[@@deriving equal, sexp]

(** Binary operations (see [interpret_binop]): [Arg1]/[Arg2] select the first/second argument,
    [ToPowOf] raises the first argument to the power of the second, and [Relu_gate] yields the
    second argument where the first is positive and [0.] otherwise. *)
type binop = Add | Sub | Mul | Div | ToPowOf | Relu_gate | Arg2 | Arg1 [@@deriving sexp, compare, equal]
(** Unary operations (see [interpret_unop]): [Relu] keeps values [>= 0.] and maps everything
    else to [0.]. *)
type unop = Identity | Relu [@@deriving sexp, compare, equal]

(** A neutral element of the operation on at least one side (left or right, depending on the
    operation). The value returned for an operation without a neutral element is arbitrary. *)
let neutral_elem op =
  match op with
  | Add | Sub | Arg1 | Arg2 -> 0.
  | Mul | Div | ToPowOf | Relu_gate -> 1.

(** [interpret_binop op v1 v2] evaluates [op] on floats:
    - [Arg1]/[Arg2] return [v1]/[v2];
    - [Add], [Sub], [Mul], [Div] are the float arithmetic operators;
    - [ToPowOf] uses repeated squaring ([Float.int_pow]) for integral exponents that fit in an
      [int], and [**] otherwise;
    - [Relu_gate] returns [v2] when [v1 > 0.] and [0.] otherwise. *)
let interpret_binop op v1 v2 =
  let open Float in
  match op with
  | Arg1 -> v1
  | Arg2 -> v2
  | Add -> v1 + v2
  | Sub -> v1 - v2
  | Mul -> v1 * v2
  | Div -> v1 / v2
  | ToPowOf -> (
      (* [is_integer] accepts huge values such as [1e300], on which [to_int] is unspecified and
         can yield an unrelated exponent; convert with a range check and fall back to [**]. *)
      match if is_integer v2 then iround_towards_zero v2 else None with
      | Some exponent -> int_pow v1 exponent
      | None -> v1 ** v2)
  | Relu_gate -> if v1 > 0.0 then v2 else 0.0

(** [interpret_unop op v] evaluates [op] on a float: [Identity] returns [v]; [Relu] returns [v]
    when [v >= 0.] and [0.] otherwise (so NaN maps to [0.]). *)
let interpret_unop op v =
  match op with
  | Identity -> v
  | Relu -> if Float.(v >= 0.) then v else 0.

(** C rendering of a binary operator as a [(prefix, infix, suffix)] triple. Presumably the
    backend emits [prefix arg1 infix arg2 suffix] — confirm against the C code generator.
    [~is_double] selects [pow] over [powf] for [ToPowOf].
    @raise Invalid_argument for [Arg1] and [Arg2], which have no C operator. *)
let binop_C_syntax ~is_double op =
  let infix symbol = ("(", symbol, ")") in
  match op with
  | Arg1 -> invalid_arg "Ops.binop_C_syntax: Arg1 is not a C operator"
  | Arg2 -> invalid_arg "Ops.binop_C_syntax: Arg2 is not a C operator"
  | Add -> infix " +"
  | Sub -> infix " -"
  | Mul -> infix " *"
  | Div -> infix " /"
  | ToPowOf -> ((if is_double then "pow(" else "powf("), ",", ")")
  | Relu_gate -> ("(", " > 0.0 ?", " : 0.0)")
(* Alternative, multiplication-based rendering of Relu_gate kept for reference:
   "((int)(", "> 0.0) *", ")" *)

(** Surface syntax of each binary operator in the [%cd] notation. *)
let binop_cd_syntax op =
  match op with
  | Add -> "+"
  | Sub -> "-"
  | Mul -> "*"
  | Div -> "/"
  | ToPowOf -> "**"
  | Relu_gate -> "-?/"
  | Arg1 -> "-@>"
  | Arg2 -> "-/>"

(** C assignment operator for an accumulating binary operation ([Arg2] is plain assignment).
    @raise Invalid_argument for [Arg1], [ToPowOf] and [Relu_gate], which have no C compound
    assignment form. *)
let assign_op_C_syntax op =
  match op with
  | Add -> "+="
  | Sub -> "-="
  | Mul -> "*="
  | Div -> "/="
  | Arg2 -> "="
  | Arg1 -> invalid_arg "Ops.assign_op_C_syntax: Arg1 is not a C assignment operator"
  | ToPowOf -> invalid_arg "Ops.assign_op_C_syntax: ToPowOf function is not a C assignment operator"
  | Relu_gate -> invalid_arg "Ops.assign_op_C_syntax: Relu_gate is not a C assignment operator"

(** [%cd] assignment operator for an accumulating binary operation. With [~initialize_neutral]
    the operator carries a [":"] (e.g. ["=:+"] instead of ["=+"]); [Arg2] is always ["=:"].
    @raise Invalid_argument for [Arg1]. *)
let assign_op_cd_syntax ~initialize_neutral op =
  match op with
  | Arg1 -> invalid_arg "Ops.assign_op_cd_syntax: Arg1 is not a %cd assignment operator"
  | Arg2 -> "=:"
  | Add -> if initialize_neutral then "=:+" else "=+"
  | Sub -> if initialize_neutral then "=:-" else "=-"
  | Mul -> if initialize_neutral then "=:*" else "=*"
  | Div -> if initialize_neutral then "=:/" else "=/"
  | ToPowOf -> if initialize_neutral then "=:**" else "=**"
  | Relu_gate -> if initialize_neutral then "=:?/" else "=?/"

(** Surface syntax of each unary operator in the [%cd] notation. *)
let unop_cd_syntax op =
  match op with
  | Identity -> "~="
  | Relu -> "?/"

(** {2 *** Global references ***} *)

(** Untyped C pointer ([void *] in ctypes terms). *)
type voidptr = unit Ctypes.ptr

(** Renders the pointer with ctypes' printer for [void *]. *)
let sexp_of_voidptr p = Sexp.Atom Ctypes.(string_of (ptr void) p)

(** Orders pointers by address. *)
let compare_voidptr = Ctypes.ptr_compare

(** Address equality, kept consistent with [compare_voidptr] so that the derived [equal] and
    [compare] of types containing a [voidptr] agree: two distinct [Ctypes.ptr] values pointing to
    the same address are equal (physical equality would say they differ). *)
let equal_voidptr (p1 : voidptr) (p2 : voidptr) : bool = Int.equal (compare_voidptr p1 p2) 0

(** Renders [ptr] as a libgccjit pointer rvalue whose pointee is the C type of [prec].
    A temporary gccjit context is created for the rendering and is released even if rendering
    raises. *)
let ptr_to_string ptr prec =
  let ctx = Gccjit.Context.create () in
  Exn.protect
    ~f:(fun () ->
      let open Gccjit in
      let ptr_typ = Type.(pointer @@ get ctx @@ c_typ_of_prec prec) in
      RValue.to_string @@ RValue.ptr ctx ptr_typ ptr)
    ~finally:(fun () -> Gccjit.Context.release ctx)

(** Global references: either a named C function, or a raw pointer to external data together
    with its precision and dimensions. *)
type global_identifier =
  | C_function of string  (** Calls a no-argument or indices-arguments C function. *)
  | External_unsafe of {
      ptr : voidptr;
      prec : (prec[@equal.ignore] [@compare.ignore]);
          (** Element precision; ignored by the derived [equal] and [compare]. *)
      dims : int array Lazy.t;
          (** Dimensions of the pointed-to data, presumably computed on demand — TODO confirm. *)
    }
[@@deriving sexp_of, equal, compare]