include Nx_core.Backend_intf.S
Types
'a is the OCaml element type (e.g., float, int32). 'b is a phantom type that tags the dtype for type safety.
Backend execution context.
Carries backend-specific state such as memory pools, device handles, command queues, or computation graphs.
Tensor Properties
view t returns the strided view metadata describing t's logical layout (shape, strides, offset) over its underlying buffer.
dtype t returns the element type of t.
val context : ('a, 'b) t -> context
context t returns the execution context that owns t.
to_host t returns t's data as a flat, C-contiguous host buffer.
Use view to interpret the logical structure. CPU backends may return a direct reference (zero-copy); GPU backends copy from device to host.
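The sketch below shows roughly how strided view metadata maps onto a flat, C-contiguous host buffer. It assumes a hypothetical `view` record and models the backing buffer as a plain `float array`. Neither is the actual Nx_core representation, and `c_strides` and `to_contiguous` are illustrative names only.

```ocaml
(* Hypothetical stand-in for view metadata; not Nx_core's actual type. *)
type view = { shape : int array; strides : int array; offset : int }

(* Row-major (C-contiguous) strides for [shape], measured in elements. *)
let c_strides shape =
  let n = Array.length shape in
  let s = Array.make n 1 in
  for i = n - 2 downto 0 do
    s.(i) <- s.(i + 1) * shape.(i + 1)
  done;
  s

(* Hypothetical helper: copy the logical elements of view [v] over the flat
   buffer [buf] into a fresh array in row-major order. Works for any strides,
   including zero (broadcast) and negative (flipped) ones. *)
let to_contiguous (v : view) (buf : float array) : float array =
  let n = Array.fold_left ( * ) 1 v.shape in
  Array.init n (fun lin ->
      (* Decompose the row-major linear index into per-axis coordinates,
         accumulating the physical position offset + sum(coord * stride). *)
      let rem = ref lin and pos = ref v.offset in
      for k = Array.length v.shape - 1 downto 0 do
        pos := !pos + (!rem mod v.shape.(k) * v.strides.(k));
        rem := !rem / v.shape.(k)
      done;
      buf.(!pos))
```

For example, a transposed view of a row-major 2×3 buffer has shape [|3; 2|] and strides [|1; 3|]; to_contiguous materializes it in row-major order of the transposed shape.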
Tensor Creation
buffer ctx dtype shape allocates an uninitialized tensor.
Contents are undefined. Used internally by the frontend to pre-allocate ~out buffers before calling operations.
Backend must: return a tensor with the given shape and dtype whose view is C-contiguous.
full ctx dtype shape value creates a tensor where every element is value.
For scalars, shape is [||]. Subsumes zeros, ones, and constant fill.
Backend must: return a C-contiguous tensor of the given shape and dtype with all elements set to value.
from_host ctx buf creates a tensor from a flat, C-contiguous host buffer.
CPU backends may share the buffer directly (zero-copy). GPU backends copy from host to device.
Frontend guarantees: buf is C-contiguous.
Element-wise Binary Operations
Frontend guarantees: out, a, and b have identical shapes (after broadcasting) and compatible dtypes (after promotion). out is C-contiguous and pre-allocated with the correct shape.
Backend must: write exactly numel elements to out, respecting the strides of a and b (which may be non-contiguous or broadcast).
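A minimal reference sketch of this contract, assuming the same hypothetical `view` record as a stand-in for Nx_core's view type and `float array` buffers. `out` is written densely, while each input is read through its own strides, so a broadcast input simply carries a stride of 0. The names `physical` and `binary_op` are illustrative, not part of the interface.

```ocaml
(* Hypothetical stand-in for view metadata; not Nx_core's actual type. *)
type view = { shape : int array; strides : int array; offset : int }

(* Physical buffer position of the element at row-major linear index [lin]. *)
let physical v lin =
  let rem = ref lin and pos = ref v.offset in
  for k = Array.length v.shape - 1 downto 0 do
    pos := !pos + (!rem mod v.shape.(k) * v.strides.(k));
    rem := !rem / v.shape.(k)
  done;
  !pos

(* Hypothetical reference binary op: [out] is C-contiguous and pre-sized;
   [a] and [b] share its logical shape but may be strided or broadcast.
   Writes exactly numel elements, one per linear index. *)
let binary_op f ~(out : float array) (va, a) (vb, b) =
  for lin = 0 to Array.length out - 1 do
    out.(lin) <- f a.(physical va lin) b.(physical vb lin)
  done
```

Here a row vector broadcast across two rows is expressed purely through a zero stride on its first axis; no data is replicated.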
Arithmetic
val add : out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> unit
add ~out a b computes out.{i} <- a.{i} + b.{i}.
val sub : out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> unit
sub ~out a b computes out.{i} <- a.{i} - b.{i}.
val mul : out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> unit
mul ~out a b computes out.{i} <- a.{i} * b.{i}.
val div : out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> unit
div ~out a b computes out.{i} <- a.{i} / b.{i}.
Integer dtypes use truncation toward zero (C division). Floating-point dtypes use IEEE 754 division.
val mod_ : out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> unit
mod_ ~out a b computes the remainder of a / b.
Integers use C's % operator (truncated division). Floats use fmod. The sign of the result follows the dividend a.
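OCaml's own integer and float operators follow the same truncated-division conventions as C, which makes them a convenient cross-check for the semantics above. The helper names below are hypothetical; only `/`, `mod`, `Int32.rem`, and `Float.rem` are real Stdlib operations.

```ocaml
(* Hypothetical helpers mirroring div/mod_ semantics with Stdlib operators. *)
let int_div a b = a / b            (* truncates toward zero, like C *)
let int_mod a b = a mod b          (* sign of the result follows the dividend *)
let float_mod a b = Float.rem a b  (* Float.rem is C's fmod *)
```

For instance, -7 / 2 is -3 (not -4), and -7 mod 2 is -1, matching C rather than floor-division languages.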
val pow : out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> unit
pow ~out base exponent computes out.{i} <- base.{i} ^ exponent.{i}.
val atan2 : out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> unit
atan2 ~out y x computes out.{i} <- atan2(y.{i}, x.{i}).
Returns the angle in radians in [-π, π], handling all quadrants (C's atan2 semantics: when x is negative and y is a signed zero, the sign of y selects π or -π).
Comparison
Comparison operations produce boolean tensors.
Frontend guarantees: out is a (bool, bool_elt) tensor with the same shape as a and b.
cmpeq ~out a b computes out.{i} <- (a.{i} = b.{i}).
cmpne ~out a b computes out.{i} <- (a.{i} <> b.{i}).
cmplt ~out a b computes out.{i} <- (a.{i} < b.{i}).
cmple ~out a b computes out.{i} <- (a.{i} <= b.{i}).
Min/Max
val max : out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> unit
max ~out a b computes out.{i} <- max(a.{i}, b.{i}).
val min : out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> unit
min ~out a b computes out.{i} <- min(a.{i}, b.{i}).
Bitwise
Operate on the binary representation of integer and boolean dtypes. For booleans, these are equivalent to logical AND/OR/XOR.
val xor : out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> unit
xor ~out a b computes bitwise XOR.
val or_ : out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> unit
or_ ~out a b computes bitwise OR.
val and_ : out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> unit
and_ ~out a b computes bitwise AND.
Element-wise Unary Operations
Frontend guarantees: out and x have the same shape and dtype. out is C-contiguous.
Backend must: write exactly numel elements to out, respecting the strides of x.
Arithmetic
val neg : out:('a, 'b) t -> ('a, 'b) t -> unit
neg ~out x computes out.{i} <- -x.{i}.
val recip : out:('a, 'b) t -> ('a, 'b) t -> unit
recip ~out x computes out.{i} <- 1 / x.{i}.
val abs : out:('a, 'b) t -> ('a, 'b) t -> unit
abs ~out x computes out.{i} <- |x.{i}|.
val sqrt : out:('a, 'b) t -> ('a, 'b) t -> unit
sqrt ~out x computes out.{i} <- √x.{i}.
val sign : out:('a, 'b) t -> ('a, 'b) t -> unit
sign ~out x computes the sign function: -1 for negative, 0 for zero, 1 for positive. Returns NaN for floating-point NaN inputs.
Exponential and Logarithm
val exp : out:('a, 'b) t -> ('a, 'b) t -> unit
exp ~out x computes out.{i} <- exp(x.{i}), that is, e raised to the power x.{i}.
val log : out:('a, 'b) t -> ('a, 'b) t -> unit
log ~out x computes out.{i} <- ln(x.{i}).
Trigonometric
All inputs are in radians.
val sin : out:('a, 'b) t -> ('a, 'b) t -> unit
sin ~out x computes out.{i} <- sin(x.{i}).
val cos : out:('a, 'b) t -> ('a, 'b) t -> unit
cos ~out x computes out.{i} <- cos(x.{i}).
val tan : out:('a, 'b) t -> ('a, 'b) t -> unit
tan ~out x computes out.{i} <- tan(x.{i}).
val asin : out:('a, 'b) t -> ('a, 'b) t -> unit
asin ~out x computes out.{i} <- arcsin(x.{i}).
Returns values in [-π/2, π/2].
val acos : out:('a, 'b) t -> ('a, 'b) t -> unit
acos ~out x computes out.{i} <- arccos(x.{i}).
Returns values in [0, π].
val atan : out:('a, 'b) t -> ('a, 'b) t -> unit
atan ~out x computes out.{i} <- arctan(x.{i}).
Returns values in [-π/2, π/2].
Hyperbolic
val sinh : out:('a, 'b) t -> ('a, 'b) t -> unit
sinh ~out x computes out.{i} <- sinh(x.{i}).
val cosh : out:('a, 'b) t -> ('a, 'b) t -> unit
cosh ~out x computes out.{i} <- cosh(x.{i}).
val tanh : out:('a, 'b) t -> ('a, 'b) t -> unit
tanh ~out x computes out.{i} <- tanh(x.{i}).
Rounding
For integer dtypes, all rounding operations are the identity.
val trunc : out:('a, 'b) t -> ('a, 'b) t -> unit
trunc ~out x rounds toward zero.
val ceil : out:('a, 'b) t -> ('a, 'b) t -> unit
ceil ~out x rounds toward positive infinity.
val floor : out:('a, 'b) t -> ('a, 'b) t -> unit
floor ~out x rounds toward negative infinity.
val round : out:('a, 'b) t -> ('a, 'b) t -> unit
round ~out x rounds to the nearest integer, with ties rounded away from zero (C's round).
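The four rounding modes can be compared side by side using OCaml's Float module, whose trunc, ceil, floor, and round functions have the same semantics as their C counterparts. The `rounding_table` helper is hypothetical and exists only for this illustration.

```ocaml
(* Hypothetical helper: (trunc, ceil, floor, round) of a float, using
   Stdlib functions that match C's trunc/ceil/floor/round. *)
let rounding_table x =
  (Float.trunc x, Float.ceil x, Float.floor x, Float.round x)
```

Note that round sends 2.5 to 3 and -2.5 to -3 (ties away from zero), which differs from the round-half-to-even rule some array libraries use.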
Special Functions
val erf : out:('a, 'b) t -> ('a, 'b) t -> unit
erf ~out x computes the error function erf(x) = 2/√π ∫₀ˣ e^(-t²) dt.
Ternary Operations
where ~out cond if_true if_false selects elements: if_true.{i} where cond.{i} is true, if_false.{i} otherwise.
Frontend guarantees: all four tensors have identical shapes. cond is boolean. out, if_true, if_false share the same dtype.
Reduction Operations
Reductions aggregate values along one or more axes.
Frontend guarantees: axes contains valid, non-negative, deduplicated axis indices. out is pre-allocated with the correct shape: reduced axes are either removed or kept as size-1 dimensions depending on keepdims.
val reduce_sum :
out:('a, 'b) t ->
axes:int array ->
keepdims:bool ->
('a, 'b) t ->
unit
reduce_sum ~out ~axes ~keepdims x sums elements along axes.
val reduce_prod :
out:('a, 'b) t ->
axes:int array ->
keepdims:bool ->
('a, 'b) t ->
unit
reduce_prod ~out ~axes ~keepdims x multiplies elements along axes.
val reduce_max :
out:('a, 'b) t ->
axes:int array ->
keepdims:bool ->
('a, 'b) t ->
unit
reduce_max ~out ~axes ~keepdims x finds the maximum along axes.
val reduce_min :
out:('a, 'b) t ->
axes:int array ->
keepdims:bool ->
('a, 'b) t ->
unit
reduce_min ~out ~axes ~keepdims x finds the minimum along axes.
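A minimal sketch of how the reduced output shape relates to axes and keepdims, together with a reference sum over a C-contiguous input stored as a flat `float array`. Both `reduced_shape` and this `reduce_sum` are hypothetical illustrations, not the backend's signatures (they return values instead of writing to ~out).

```ocaml
(* Hypothetical helper: output shape of a reduction over [axes]. Reduced axes
   become size 1 when [keepdims] is true and are dropped otherwise. *)
let reduced_shape ~axes ~keepdims shape =
  if keepdims then Array.mapi (fun i d -> if Array.mem i axes then 1 else d) shape
  else
    shape |> Array.to_list
    |> List.filteri (fun i _ -> not (Array.mem i axes))
    |> Array.of_list

(* Hypothetical reference sum over a C-contiguous input [x] of [shape].
   Returns flat output data; the layout is identical with or without
   keepdims, since size-1 axes do not change row-major order. *)
let reduce_sum ~axes shape (x : float array) =
  let kshape = reduced_shape ~axes ~keepdims:true shape in
  let out = Array.make (Array.fold_left ( * ) 1 kshape) 0. in
  let ndim = Array.length shape in
  Array.iteri
    (fun lin v ->
      (* Map the input linear index to an output linear index by
         dropping the coordinates of the reduced axes. *)
      let rem = ref lin and o = ref 0 and ostride = ref 1 in
      for k = ndim - 1 downto 0 do
        let i = !rem mod shape.(k) in
        rem := !rem / shape.(k);
        if not (Array.mem k axes) then o := !o + (i * !ostride);
        ostride := !ostride * kshape.(k)
      done;
      out.(!o) <- out.(!o) +. v)
    x;
  out
```

The same index mapping works for prod, max, and min by changing the initial value and the combining function.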
argmax ~out ~axis ~keepdims x writes int32 indices of maximum values along axis to out. For ties, returns the first occurrence.
Frontend guarantees: axis is valid and non-negative. out has the correct reduced shape with int32 dtype.
argmin ~out ~axis ~keepdims x writes int32 indices of minimum values along axis to out. For ties, returns the first occurrence.
Frontend guarantees: axis is valid and non-negative. out has the correct reduced shape with int32 dtype.
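The first-occurrence tie rule falls out naturally from a strict comparison. Below is a sketch over a single 1D lane, with the hypothetical name `argmax_lane`; it does not attempt to define NaN handling, which the interface above leaves unspecified. argmin is symmetric with `<`.

```ocaml
(* Hypothetical helper: argmax over one lane, returning an int32 index to
   match the dtype the frontend allocates for [out]. *)
let argmax_lane (a : float array) : int32 =
  let best = ref 0 in
  for i = 1 to Array.length a - 1 do
    (* Strict [>] keeps the earliest maximum when values tie. *)
    if a.(i) > a.(!best) then best := i
  done;
  Int32.of_int !best
```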
val associative_scan :
out:('a, 'b) t ->
axis:int ->
op:[ `Sum | `Prod | `Max | `Min ] ->
('a, 'b) t ->
unit
associative_scan ~out ~axis ~op x computes an inclusive prefix scan along axis. `Sum for cumulative sum, `Prod for cumulative product, `Max/`Min for running max/min.
Frontend guarantees: axis is valid and non-negative. out has the same shape as x.
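A sequential sketch of the inclusive scan along a single 1D lane. The name `inclusive_scan` is hypothetical, and a real backend might use a parallel scan instead, which is valid precisely because each operator is associative.

```ocaml
(* Hypothetical helper: inclusive prefix scan; out.(i) combines x.(0..i). *)
let inclusive_scan op (x : float array) =
  let f =
    match op with
    | `Sum -> ( +. )
    | `Prod -> ( *. )
    | `Max -> Float.max
    | `Min -> Float.min
  in
  let out = Array.copy x in
  for i = 1 to Array.length out - 1 do
    out.(i) <- f out.(i - 1) x.(i)
  done;
  out
```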
Sort Operations
Frontend guarantees: axis is valid and non-negative. out is pre-allocated with the correct shape and dtype.
val sort : out:('a, 'b) t -> axis:int -> descending:bool -> ('a, 'b) t -> unit
sort ~out ~axis ~descending x sorts elements along axis. NaN values are placed at the end regardless of sort direction.
Frontend guarantees: out has the same shape and dtype as x.
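Placing NaN last in both directions means the comparator cannot simply be negated for descending order. One way to express this is shown below for a single lane; `sort_lane` is a hypothetical name, and using a stable sort is a choice of this sketch, not a documented requirement.

```ocaml
(* Hypothetical helper: sort one lane, with NaN last in both directions. *)
let sort_lane ~descending (a : float array) =
  let cmp x y =
    match (Float.is_nan x, Float.is_nan y) with
    | true, true -> 0
    | true, false -> 1 (* NaN sorts after any number *)
    | false, true -> -1
    | false, false -> if descending then Float.compare y x else Float.compare x y
  in
  let out = Array.copy a in
  Array.stable_sort cmp out;
  out
```

The same comparator, applied to a.(i) and a.(j) while sorting an index array, gives the corresponding argsort ordering.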
argsort ~out ~axis ~descending x writes int32 indices that would sort elements along axis to out.
Frontend guarantees: out has the same shape as x with int32 dtype.
Movement Operations
Movement operations manipulate view metadata (shape, strides, offset) without copying data when possible. They return new tensor handles sharing the underlying buffer.
Frontend guarantees: all parameters are validated (axes in range, shapes compatible, bounds within limits).
Backend must: return a tensor with the correct view metadata. May share the underlying buffer (zero-copy) or allocate if necessary.
val expand : ('a, 'b) t -> int array -> ('a, 'b) t
expand t shape broadcasts dimensions of size 1 to match shape by setting their stride to 0. Non-singleton dimensions must already match. Zero-copy.
val reshape : ('a, 'b) t -> int array -> ('a, 'b) t
reshape t shape changes the logical shape, preserving the element count.
Zero-copy when t is C-contiguous or the new shape is compatible with the current strides; otherwise the backend may copy.
val permute : ('a, 'b) t -> int array -> ('a, 'b) t
permute t axes reorders dimensions according to axes, which must be a permutation of [0, ..., ndim-1]. Zero-copy.
val shrink : ('a, 'b) t -> (int * int) array -> ('a, 'b) t
shrink t ranges extracts a rectangular sub-region. ranges.(i) is (start, stop) for axis i, with stop exclusive. Zero-copy (adjusts offset and shape).
val flip : ('a, 'b) t -> bool array -> ('a, 'b) t
flip t axes reverses the dimensions where axes.(i) = true by negating their strides and moving the offset to the last element along each flipped axis. Zero-copy.
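The zero-copy movement operations above only rewrite view metadata. The sketch below expresses expand, permute, shrink, and flip as pure functions on a hypothetical `view` record (not Nx_core's actual type). It assumes expand receives a target shape of the same rank, as the description of expand implies.

```ocaml
(* Hypothetical stand-in for view metadata; not Nx_core's actual type. *)
type view = { shape : int array; strides : int array; offset : int }

(* expand: size-1 axes take the target size and a stride of 0. *)
let expand v target =
  { v with
    shape = Array.copy target;
    strides =
      Array.mapi
        (fun i s -> if v.shape.(i) = 1 && target.(i) <> 1 then 0 else s)
        v.strides }

(* permute: output axis i is input axis axes.(i). *)
let permute v axes =
  { v with
    shape = Array.map (fun a -> v.shape.(a)) axes;
    strides = Array.map (fun a -> v.strides.(a)) axes }

(* shrink: advance the offset to each range's start and narrow the shape. *)
let shrink v ranges =
  let offset = ref v.offset in
  Array.iteri (fun i (start, _) -> offset := !offset + (start * v.strides.(i))) ranges;
  { v with shape = Array.map (fun (start, stop) -> stop - start) ranges; offset = !offset }

(* flip: point the offset at the last element of each flipped axis,
   then negate that axis's stride. *)
let flip v axes =
  let offset = ref v.offset in
  let strides =
    Array.mapi
      (fun i s ->
        if axes.(i) then begin
          offset := !offset + ((v.shape.(i) - 1) * s);
          -s
        end
        else s)
      v.strides
  in
  { v with strides; offset = !offset }
```

None of these functions touches a data buffer, which is what makes the operations zero-copy.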
val pad : ('a, 'b) t -> (int * int) array -> 'a -> ('a, 'b) t
pad t padding fill_value extends t with fill_value. padding.(i) is (before, after) for dimension i.
Backend must: allocate a new buffer and copy data.
val cat : out:('a, 'b) t -> ('a, 'b) t list -> axis:int -> unit
cat ~out tensors ~axis concatenates tensors along axis into out.
Frontend guarantees: all tensors have the same shape except along axis. axis is valid. The list is non-empty. out is pre-allocated with the correct concatenated shape.
Type Conversion and Memory
val cast : out:('c, 'd) t -> ('a, 'b) t -> unit
cast ~out x converts elements of x to the dtype of out.
Float-to-int truncates toward zero. Int-to-float may lose precision for large values.
Frontend guarantees: out is pre-allocated with the correct shape and target dtype.
val contiguous : ('a, 'b) t -> ('a, 'b) t
contiguous t returns a C-contiguous version of t.
May return t unchanged if already C-contiguous. Otherwise allocates and copies.
Backend must: return a C-contiguous tensor with the same data.
val copy : ('a, 'b) t -> ('a, 'b) t
copy t creates an independent copy with its own buffer.
Backend must: always allocate a new buffer, even if t is already contiguous.
val assign : ('a, 'b) t -> ('a, 'b) t -> unit
assign dst src copies elements from src into dst in-place.
Frontend guarantees: dst and src have matching shapes and dtypes.
Backend must: write src's data into dst's buffer, respecting both tensors' strides.
Random Number Generation
threefry ~out key counter applies the Threefry-2x32 hash function.
Frontend guarantees: key and counter are int32 tensors with compatible shapes. out is pre-allocated with the same shape as counter.
Indexed Access Operations
gather ~out data indices ~axis selects elements from data along axis using indices and writes them to out.
Frontend guarantees: rank data = rank indices. axis is valid. Index values are in range for data's size along axis. out has the same shape as indices and the same dtype as data.
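Given the guarantees above (equal ranks, output shaped like indices), gather behaves like a "take along axis" lookup. A 2D sketch follows, using nested OCaml arrays instead of tensors. The name `gather_2d` and the nested-array representation are assumptions for illustration only.

```ocaml
(* Hypothetical 2D gather: out.(i).(j) = data.(indices.(i).(j)).(j) for
   axis 0, and data.(i).(indices.(i).(j)) for axis 1. *)
let gather_2d ~axis (data : float array array) (indices : int32 array array) =
  Array.mapi
    (fun i row ->
      Array.mapi
        (fun j idx ->
          let k = Int32.to_int idx in
          if axis = 0 then data.(k).(j) else data.(i).(k))
        row)
    indices
```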
val scatter :
?mode:[ `Set | `Add ] ->
?unique_indices:bool ->
('a, 'b) t ->
indices:(int32, Nx_core.Dtype.int32_elt) t ->
updates:('a, 'b) t ->
axis:int ->
('a, 'b) t
scatter ?mode ?unique_indices template ~indices ~updates ~axis places updates into a tensor shaped like template along axis.
`Set (default) uses the last update for duplicate indices. `Add accumulates. unique_indices = true hints that indices are unique.
Frontend guarantees: rank indices = rank updates. axis is valid. template has the desired output shape.
Backend must: allocate and return the result tensor, initialized from template's data.
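The difference between `Set and `Add only shows up with duplicate indices. Below is a 1D sketch applying updates sequentially; the name `scatter_1d` is hypothetical, and a parallel backend would need extra care (for example, atomics for `Add) to reproduce these results.

```ocaml
(* Hypothetical 1D scatter. The result starts as a copy of [template]. *)
let scatter_1d ?(mode : [ `Set | `Add ] = `Set) (template : float array)
    ~(indices : int32 array) ~(updates : float array) =
  let out = Array.copy template in
  Array.iteri
    (fun n idx ->
      let k = Int32.to_int idx in
      match mode with
      | `Set -> out.(k) <- updates.(n) (* later updates overwrite earlier ones *)
      | `Add -> out.(k) <- out.(k) +. updates.(n))
    indices;
  out
```

With indices [|0l; 2l; 0l|], `Set keeps the last update written to position 0, while `Add sums both updates onto the template value.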
Window Operations
Sliding-window extraction and its inverse. Used to implement convolution as unfold + reshape + matmul and pooling as unfold + reduce.
val unfold :
('a, 'b) t ->
kernel_size:int array ->
stride:int array ->
dilation:int array ->
padding:(int * int) array ->
('a, 'b) t
unfold t ~kernel_size ~stride ~dilation ~padding extracts sliding windows from the last K spatial dimensions, where K = Array.length kernel_size.
Input shape (leading..., spatial...) produces (leading..., prod(kernel_size), L) where L is the number of windows. All dimensions before the last K are preserved as-is.
Frontend guarantees: all array parameters have length K. Values are positive. Input has at least K dimensions.
Backend must: allocate and return the result tensor.
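The window count L is not spelled out above. The sketch assumes the usual convolution formula (the one used, for example, by PyTorch's Unfold), computed per spatial dimension and multiplied across dimensions. `windows_along` and `unfold_output_shape` are hypothetical helpers that take the leading and spatial parts of the input shape separately.

```ocaml
(* Assumed standard formula: number of windows along one spatial dimension. *)
let windows_along ~input ~kernel ~stride ~dilation ~pad:(before, after) =
  ((input + before + after - (dilation * (kernel - 1)) - 1) / stride) + 1

(* Hypothetical helper: (leading..., prod(kernel_size), L) for unfold. *)
let unfold_output_shape ~leading ~spatial ~kernel_size ~stride ~dilation ~padding =
  let l = ref 1 in
  Array.iteri
    (fun i input ->
      l :=
        !l
        * windows_along ~input ~kernel:kernel_size.(i) ~stride:stride.(i)
            ~dilation:dilation.(i) ~pad:padding.(i))
    spatial;
  Array.append leading [| Array.fold_left ( * ) 1 kernel_size; !l |]
```

For a (1, 3, 5, 5) input with a 3×3 kernel, unit stride and dilation, and no padding, this gives (1, 3, 9, 9): nine kernel positions per window and nine windows.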
val fold :
('a, 'b) t ->
output_size:int array ->
kernel_size:int array ->
stride:int array ->
dilation:int array ->
padding:(int * int) array ->
('a, 'b) t
fold t ~output_size ~kernel_size ~stride ~dilation ~padding combines sliding windows (inverse of unfold). Overlapping values are summed.
Input shape (leading..., prod(kernel_size), L) produces (leading..., output_size...).
Frontend guarantees: parameters are consistent with a valid unfold configuration.
Backend must: allocate and return the result tensor.
Matrix Operations
val matmul : out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> unit
matmul ~out a b computes matrix multiplication a × b.
For 2D inputs: standard matrix multiply. For higher dimensions: batched multiply on the last two dimensions, with broadcasting via strides.
Frontend guarantees: a's last dim equals b's second-to-last dim. out is C-contiguous with the correct output shape.
Backend must: write the result to out. May use BLAS for performance. a and b may be non-contiguous.
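For reference, the 2D case reduces to the textbook triple loop below, written over nested OCaml arrays. `matmul_2d` is a hypothetical name; a real backend would write into ~out, handle batch dimensions, and likely call BLAS instead.

```ocaml
(* Hypothetical reference: (m×k) × (k×n) -> (m×n), naive O(m·n·k). *)
let matmul_2d (a : float array array) (b : float array array) =
  let m = Array.length a and k = Array.length b in
  let n = if k = 0 then 0 else Array.length b.(0) in
  Array.init m (fun i ->
      Array.init n (fun j ->
          let acc = ref 0. in
          for p = 0 to k - 1 do
            acc := !acc +. (a.(i).(p) *. b.(p).(j))
          done;
          !acc))
```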
Fourier Transforms
Frontend guarantees: axes contains valid, non-negative axis indices. Input tensors have compatible complex or real dtypes.
fft ?out t ~axes computes the forward DFT along axes.
ifft ?out t ~axes computes the inverse DFT along axes.
rfft ?out t ~dtype ~axes computes the real-input DFT along axes.
Exploits conjugate symmetry to return only the non-redundant half of the spectrum along the last transformed axis.
val irfft :
?out:(float, 'b) t ->
?s:int array ->
(Complex.t, 'a) t ->
dtype:(float, 'b) Nx_core.Dtype.t ->
axes:int array ->
(float, 'b) t
irfft ?out ?s t ~dtype ~axes computes the inverse real-input DFT along axes.
Takes conjugate-symmetric complex input and returns real output. s specifies the output sizes along the transformed axes; when s is omitted, sizes are inferred from the input.
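Why rfft keeps only part of the spectrum: for a real signal of length n, bin n-k is the complex conjugate of bin k, so bins 0 through n/2 (n/2 + 1 values) carry all the information. The naive O(n²) sketch below, with the hypothetical name `rfft_naive`, computes exactly those bins using the Stdlib Complex module.

```ocaml
(* Hypothetical naive real-input DFT: X_k = sum_t x_t e^(-2πi·k·t/n),
   computed only for the non-redundant bins k = 0 .. n/2. *)
let rfft_naive (x : float array) : Complex.t array =
  let n = Array.length x in
  Array.init ((n / 2) + 1) (fun k ->
      let acc = ref Complex.zero in
      Array.iteri
        (fun t xt ->
          let angle = -2. *. Float.pi *. float_of_int (k * t) /. float_of_int n in
          acc := Complex.add !acc (Complex.polar xt angle))
        x;
      !acc)
```

Because n/2 + 1 bins are produced for both n = 2m and n = 2m + 1, the original length cannot always be recovered from the half-spectrum alone. This is presumably why irfft accepts s. If the backend follows NumPy's convention (an assumption, not stated above), an omitted s defaults to 2·(m - 1) for m input bins.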
Linear Algebra
All linalg operations support batching: the last two dimensions are the matrix dimensions, earlier dimensions are batch dimensions.
Frontend guarantees: input matrices have compatible shapes (square where required, matching dimensions for solves).
Backend must: allocate and return result tensors. Typically delegates to LAPACK.
val cholesky : upper:bool -> ('a, 'b) t -> ('a, 'b) t
cholesky ~upper t computes the Cholesky factorization of a positive-definite matrix. Returns L (lower) or U (upper) such that A = L·Lᵀ or A = Uᵀ·U.
val qr : reduced:bool -> ('a, 'b) t -> ('a, 'b) t * ('a, 'b) t
qr ~reduced t returns (Q, R) where Q is orthogonal and R is upper triangular. reduced = true returns the economy-size factorization.
svd ~full_matrices t returns (U, S, Vᴴ). S is a 1D float64 vector of singular values in descending order. full_matrices = false returns thin SVD.
eig ~vectors t computes eigenvalues (and optionally eigenvectors) of a square matrix. Returns complex64 results.
eigh ~vectors t computes eigenvalues (and optionally eigenvectors) of a symmetric/Hermitian matrix. Eigenvalues are float64.
val triangular_solve :
upper:bool ->
transpose:bool ->
unit_diag:bool ->
('a, 'b) t ->
('a, 'b) t ->
('a, 'b) t
triangular_solve ~upper ~transpose ~unit_diag a b solves A·x = b or Aᵀ·x = b where A is triangular.
upper: A is upper triangular. transpose: solve Aᵀ·x = b. unit_diag: assume diagonal is all ones.
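The lower-triangular, non-transposed case is plain forward substitution, sketched below for a single right-hand side over nested OCaml arrays. `solve_lower` is a hypothetical name; the upper-triangular case runs the same loop backwards (back substitution), and a LAPACK-backed implementation would typically call a routine such as trsm instead.

```ocaml
(* Hypothetical helper: solve A·x = b for lower-triangular A.
   With [unit_diag], the diagonal is taken to be 1 and never read. *)
let solve_lower ?(unit_diag = false) (a : float array array) (b : float array) =
  let n = Array.length b in
  let x = Array.make n 0. in
  for i = 0 to n - 1 do
    (* Subtract contributions of already-solved unknowns x.(0..i-1). *)
    let s = ref b.(i) in
    for j = 0 to i - 1 do
      s := !s -. (a.(i).(j) *. x.(j))
    done;
    x.(i) <- (if unit_diag then !s else !s /. a.(i).(i))
  done;
  x
```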