Source file backend_intf.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
(** Backend interface.

    The [op_*] functions mirror tinygrad's UOps. A backend can execute them
    eagerly, raise effects for a JIT, build a computation graph, etc.

    The frontend handles broadcasting and shape validation before calling a
    backend, so an operation may assume its inputs are already compatible.
    Every operation returns a new tensor handle, except [op_assign], which
    returns [unit]. View-like operations such as [op_reshape] need not move
    data, so the returned handle may share storage with its input. *)
module type S = sig
  type ('a, 'b) t
  (** Opaque tensor handle. ['a] is the OCaml element type; ['b] tags the
      dtype. *)

  type context
  (** Backend execution context. Carries any state required by the
      implementation (memory pools, command queues, ...). *)

  (** {2 Accessors} *)

  val view : ('a, 'b) t -> Lazy_view.t
  (** [view t] returns the view tracker of [t]. NOTE(review): presumably this
      is what relates the logical indices of [t] to positions in [data t] -
      confirm. *)

  val dtype : ('a, 'b) t -> ('a, 'b) Dtype.t
  (** [dtype t] is the element type of [t]. *)

  val context : ('a, 'b) t -> context
  (** [context t] is the execution context of [t]. *)

  val data : ('a, 'b) t -> ('a, 'b, Bigarray.c_layout) Bigarray.Array1.t
  (** [data t] returns the raw, flat buffer of [t]. This signature does not
      describe its element order; interpret it together with [view t]. *)

  (** {2 Creation} *)

  val op_buffer : context -> ('a, 'b) Dtype.t -> int -> ('a, 'b) t
  (** [op_buffer ctx dtype size] allocates a buffer of [size] elements of
      [dtype] in [ctx]. The initial contents are not specified by this
      interface. *)

  val op_const_scalar : context -> 'a -> ('a, 'b) Dtype.t -> ('a, 'b) t
  (** [op_const_scalar ctx value dtype] is a tensor holding the single scalar
      [value] with element type [dtype]. *)

  val op_const_array :
    context -> ('a, 'b, Bigarray.c_layout) Bigarray.Array1.t -> ('a, 'b) t
  (** [op_const_array ctx array] is a tensor containing the elements of
      [array]. [array] must be contiguous. *)

  (** {2 Element-wise binary arithmetic}

      Operands have already been broadcast to a common shape by the
      frontend. *)

  val op_add : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
  (** [op_add a b] is the element-wise sum of [a] and [b]. *)

  val op_mul : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
  (** [op_mul a b] is the element-wise product of [a] and [b]. *)

  val op_idiv : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
  (** [op_idiv a b] is the element-wise integer division of [a] by [b],
      truncating (rounding toward zero). *)

  val op_fdiv : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
  (** [op_fdiv a b] is the element-wise floating-point division of [a] by
      [b]. *)

  val op_max : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
  (** [op_max a b] is the element-wise maximum of [a] and [b]. *)

  val op_mod : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
  (** [op_mod a b] is the element-wise integer modulus of [a] by [b].
      NOTE(review): the sign convention for negative operands is not stated
      here - confirm against implementations. *)

  val op_pow : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
  (** [op_pow base exponent] raises each element of [base] to the matching
      element of [exponent]. *)

  (** {2 Comparisons}

      Results are uint8 tensors holding [1] where the comparison holds and
      [0] elsewhere. *)

  val op_cmplt : ('a, 'b) t -> ('a, 'b) t -> (int, Dtype.uint8_elt) t
  (** [op_cmplt a b] tests [a < b] element-wise. *)

  val op_cmpne : ('a, 'b) t -> ('a, 'b) t -> (int, Dtype.uint8_elt) t
  (** [op_cmpne a b] tests [a <> b] element-wise. *)

  (** {2 Bitwise operations} *)

  val op_xor : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
  (** [op_xor a b] is the element-wise bitwise XOR of [a] and [b]. *)

  val op_or : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
  (** [op_or a b] is the element-wise bitwise OR of [a] and [b]. *)

  val op_and : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
  (** [op_and a b] is the element-wise bitwise AND of [a] and [b]. *)

  (** {2 Element-wise unary operations} *)

  val op_neg : ('a, 'b) t -> ('a, 'b) t
  (** [op_neg t] is the element-wise negation of [t] (logical not for boolean
      tensors). *)

  val op_log2 : ('a, 'b) t -> ('a, 'b) t
  (** [op_log2 t] is the element-wise base-2 logarithm of [t]. *)

  val op_exp2 : ('a, 'b) t -> ('a, 'b) t
  (** [op_exp2 t] raises 2 to the power of each element of [t]. *)

  val op_sin : ('a, 'b) t -> ('a, 'b) t
  (** [op_sin t] is the element-wise sine of [t]. *)

  val op_sqrt : ('a, 'b) t -> ('a, 'b) t
  (** [op_sqrt t] is the element-wise square root of [t]. *)

  val op_recip : ('a, 'b) t -> ('a, 'b) t
  (** [op_recip t] is the element-wise reciprocal of [t]. *)

  (** {2 Selection} *)

  val op_where :
    (int, Dtype.uint8_elt) t -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
  (** [op_where cond if_true if_false] takes each element from [if_true]
      where [cond] is true and from [if_false] where it is false. [cond] is a
      uint8 boolean tensor, in the same 0/1 encoding produced by [op_cmplt]
      and [op_cmpne]. *)

  (** {2 Reductions and scans}

      For the [op_reduce_*] functions, reduced dimensions are kept (with size
      1) when [keepdims] is [true] and dropped otherwise. *)

  val op_reduce_sum :
    axes:int array -> keepdims:bool -> ('a, 'b) t -> ('a, 'b) t
  (** [op_reduce_sum ~axes ~keepdims t] is the sum of [t] over [axes]. *)

  val op_reduce_max :
    axes:int array -> keepdims:bool -> ('a, 'b) t -> ('a, 'b) t
  (** [op_reduce_max ~axes ~keepdims t] is the maximum of [t] over [axes]. *)

  val op_reduce_prod :
    axes:int array -> keepdims:bool -> ('a, 'b) t -> ('a, 'b) t
  (** [op_reduce_prod ~axes ~keepdims t] is the product of [t] over [axes]. *)

  val op_associative_scan :
    axis:int -> op:[ `Sum | `Prod | `Max | `Min ] -> ('a, 'b) t -> ('a, 'b) t
  (** [op_associative_scan ~axis ~op t] is the inclusive scan of [t] along
      [axis] under [op]: position [i] along [axis] combines positions [0]
      through [i] of [t]. *)

  (** {2 Movement operations} *)

  val op_expand : ('a, 'b) t -> Symbolic_shape.t -> ('a, 'b) t
  (** [op_expand t shape] broadcasts the size-1 dimensions of [t] to
      [shape]. *)

  val op_reshape : ('a, 'b) t -> Symbolic_shape.t -> ('a, 'b) t
  (** [op_reshape t shape] changes the logical shape of [t] to [shape]
      without moving data. *)

  val op_permute : ('a, 'b) t -> int array -> ('a, 'b) t
  (** [op_permute t axes] reorders the dimensions of [t] according to the
      permutation [axes]. *)

  val op_shrink : ('a, 'b) t -> (int * int) array -> ('a, 'b) t
  (** [op_shrink t bounds] slices [t]; [bounds.(i)] is the [(start, stop)]
      pair for dimension [i]. NOTE(review): [stop] is presumably exclusive -
      confirm. *)

  val op_flip : ('a, 'b) t -> bool array -> ('a, 'b) t
  (** [op_flip t mask] reverses every dimension [i] of [t] for which
      [mask.(i)] is [true]. *)

  val op_pad : ('a, 'b) t -> (int * int) array -> 'a -> ('a, 'b) t
  (** [op_pad t padding fill_value] pads [t] with [fill_value].
      NOTE(review): [padding.(i)] is presumably the [(before, after)] amount
      for dimension [i] - confirm. *)

  (** {2 Concatenation, casting and copying} *)

  val op_cat : ('a, 'b) t list -> int -> ('a, 'b) t
  (** [op_cat ts axis] concatenates the tensors [ts] along [axis]. *)

  val op_cast : ('a, 'b) t -> ('c, 'd) Dtype.t -> ('c, 'd) t
  (** [op_cast t target_dtype] converts the elements of [t] to
      [target_dtype]. *)

  val op_contiguous : ('a, 'b) t -> ('a, 'b) t
  (** [op_contiguous t] returns a C-contiguous tensor with the contents of
      [t]. May copy. *)

  val op_copy : ('a, 'b) t -> ('a, 'b) t
  (** [op_copy t] duplicates [t]. The result has its own buffer. *)

  val op_assign : ('a, 'b) t -> ('a, 'b) t -> unit
  (** [op_assign dst src] stores the contents of [src] into [dst] in place.
      NOTE(review): the argument order (destination first) is not enforced by
      the types - confirm against implementations. *)

  (** {2 Random numbers} *)

  val op_threefry :
    (int32, Dtype.int32_elt) t ->
    (int32, Dtype.int32_elt) t ->
    (int32, Dtype.int32_elt) t
  (** Threefry counter-based random number generator over int32 tensors.
      NOTE(review): which argument is the key and which is the counter is not
      stated here - confirm against implementations. *)

  (** {2 Indexing} *)

  val op_gather :
    ('a, 'b) t ->
    (int32, Dtype.int32_elt) t ->
    int ->
    ('a, 'b) t
  (** [op_gather data indices axis] gathers elements of [data] along [axis]
      using [indices]. The output has the shape of [indices]. [data] and
      [indices] must have the same rank, and along every dimension other than
      [axis] the size of [indices] must not exceed that of [data]. *)

  val op_scatter :
    ?mode:[ `Set | `Add ] ->
    ?unique_indices:bool ->
    ('a, 'b) t ->
    (int32, Dtype.int32_elt) t ->
    ('a, 'b) t ->
    int ->
    ('a, 'b) t
  (** [op_scatter ?mode ?unique_indices data_template indices updates axis]
      scatters [updates] along [axis] at [indices] into a new tensor shaped
      like [data_template].

      - [mode] controls duplicate indices: [`Set] (the default) keeps the
        last update, [`Add] accumulates all updates.
      - [unique_indices] is an optimization hint asserting that [indices]
        contains no duplicates.

      NOTE(review): whether positions not targeted by [indices] take their
      value from [data_template] is not stated here - confirm. *)

  (** {2 Sliding windows} *)

  val op_unfold :
    ('a, 'b) t ->
    kernel_size:int array ->
    stride:int array ->
    dilation:int array ->
    padding:(int * int) array ->
    ('a, 'b) t
  (** [op_unfold t ~kernel_size ~stride ~dilation ~padding] is the unfold
      (im2col) operation: it extracts sliding local blocks from a batched
      input. For an input of shape (N, C, *spatial_dims) it produces an
      output of shape (N, C * prod(kernel_size), L), where L is the number of
      blocks. Works for any number of spatial dimensions. *)

  val op_fold :
    ('a, 'b) t ->
    output_size:int array ->
    kernel_size:int array ->
    stride:int array ->
    dilation:int array ->
    padding:(int * int) array ->
    ('a, 'b) t
  (** [op_fold t ~output_size ~kernel_size ~stride ~dilation ~padding] is the
      fold (col2im) operation, the inverse of [op_unfold]: it combines
      sliding local blocks into a tensor. For an input of shape
      (N, C * prod(kernel_size), L) it produces an output of shape
      (N, C, *output_size). Overlapping values are summed. Works for any
      number of spatial dimensions. *)

  (** {2 Matrix multiplication} *)

  val op_matmul : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
  (** [op_matmul a b] is the matrix product of [a] and [b]. For 2-D tensors
      this is standard matrix multiplication; for higher ranks it is a
      batched multiplication over the last two dimensions, broadcasting the
      batch dimensions as needed. The last dimension of [a] must match the
      second-to-last dimension of [b]. *)

  (** {2 Fourier transforms}

      NOTE(review): the normalization convention of these transforms is not
      stated here - confirm against implementations. *)

  val op_fft : (Complex.t, 'b) t -> axes:int array -> (Complex.t, 'b) t
  (** [op_fft t ~axes] is the discrete Fourier transform (DFT) of [t] over
      [axes]. *)

  val op_ifft : (Complex.t, 'b) t -> axes:int array -> (Complex.t, 'b) t
  (** [op_ifft t ~axes] is the inverse discrete Fourier transform (IDFT) of
      [t] over [axes]. *)

  val op_rfft :
    (float, 'a) t ->
    dtype:(Complex.t, 'b) Dtype.t ->
    axes:int array ->
    (Complex.t, 'b) t
  (** [op_rfft t ~dtype ~axes] is the DFT of the real-valued input [t] over
      [axes], with complex result type [dtype]. *)

  val op_irfft :
    (Complex.t, 'a) t ->
    dtype:(float, 'b) Dtype.t ->
    axes:int array ->
    s:int array option ->
    (float, 'b) t
  (** [op_irfft t ~dtype ~axes ~s] is the inverse real-valued DFT of [t] over
      [axes], with real result type [dtype]. NOTE(review): [s], when
      [Some sizes], presumably fixes the output length along each of [axes] -
      confirm. *)

  (** {2 Linear algebra}

      Input matrices may be batched. The boolean labelled arguments below are
      required: this interface provides no defaults, so the caller always
      chooses them. *)

  val op_cholesky : upper:bool -> ('a, 'b) t -> ('a, 'b) t
  (** [op_cholesky ~upper a] is the Cholesky factor of the positive-definite
      square matrix [a].

      - [upper]: if [true], return the upper-triangular factor U with
        A = U^T*U; otherwise return the lower-triangular factor L with
        A = L*L^T.
      - Raises if [a] is not positive-definite.

      NOTE(review): for complex inputs the transpose is presumably the
      conjugate transpose - confirm. *)

  val op_qr : reduced:bool -> ('a, 'b) t -> ('a, 'b) t * ('a, 'b) t
  (** [op_qr ~reduced a] is the QR decomposition [(q, r)] of the m x n matrix
      [a], with A = Q*R, Q orthogonal and R upper-triangular.

      - [reduced]: if [true], return the economy (reduced) factorization;
        otherwise return the full one. *)

  val op_svd :
    full_matrices:bool ->
    ('a, 'b) t ->
    ('a, 'b) t * (float, Dtype.float64_elt) t * ('a, 'b) t
  (** [op_svd ~full_matrices a] is the singular value decomposition
      [(u, s, vh)] of the m x n matrix [a], with A = U*S*V^H.

      - [full_matrices]: if [false], return the thin SVD; otherwise return
        the full one.
      - [s] is a 1-D vector of singular values in descending order, always
        float64. *)

  val op_eig :
    vectors:bool ->
    ('a, 'b) t ->
    (Complex.t, Dtype.complex64_elt) t
    * (Complex.t, Dtype.complex64_elt) t option
  (** [op_eig ~vectors a] is the general eigenvalue decomposition of the
      square matrix [a]: its eigenvalues and, when [vectors] is [true], its
      eigenvectors. Both are always complex64. NOTE(review): the eigenvector
      component is presumably [None] when [vectors] is [false] - confirm. *)

  val op_eigh :
    vectors:bool ->
    ('a, 'b) t ->
    (float, Dtype.float64_elt) t * ('a, 'b) t option
  (** [op_eigh ~vectors a] is the eigenvalue decomposition of the symmetric
      (real) or Hermitian (complex) matrix [a]: eigenvalues as float64 and,
      when [vectors] is [true], eigenvectors with the same type as [a].
      NOTE(review): the eigenvector component is presumably [None] when
      [vectors] is [false] - confirm. *)

  val op_triangular_solve :
    upper:bool ->
    transpose:bool ->
    unit_diag:bool ->
    ('a, 'b) t ->
    ('a, 'b) t ->
    ('a, 'b) t
  (** [op_triangular_solve ~upper ~transpose ~unit_diag a b] solves a
      triangular system for the matrix [a] and right-hand side [b] (both
      possibly batched), returning the solution x.

      - [upper]: if [true], [a] is upper-triangular; otherwise lower.
      - [transpose]: if [true], solve A^T*x = b; otherwise solve A*x = b.
      - [unit_diag]: if [true], assume the diagonal of [a] is all ones. *)

  (** {2 Strided views} *)

  val op_as_strided :
    ('a, 'b) t -> Symbolic_shape.t -> int array -> int -> ('a, 'b) t
  (** [op_as_strided t shape strides offset] is a strided view of [t] with
      the given [shape], [strides] (in elements) and [offset] (in elements).
      Backends that support arbitrary strided views (e.g. native with
      Bigarray) can implement this as zero-copy; other backends may fall back
      to copying data. Raises if the view would access out-of-bounds
      memory. *)
end