package nx

  1. Overview
  2. Docs

Module Nx_cSource

include Nx_core.Backend_intf.S
Sourcetype ('a, 'b) t

Opaque tensor handle.

'a is the OCaml element type; 'b tags the dtype.

Sourcetype context

Backend execution context. Carries any state required by the implementation (memory pools, command queues, ...).

Sourceval view : ('a, 'b) t -> Nx_core.Lazy_view.t

Return the view tracker for t.

Sourceval dtype : ('a, 'b) t -> ('a, 'b) Nx_core.Dtype.t

Element type of t.

Sourceval context : ('a, 'b) t -> context

Execution context of t.

Sourceval data : ('a, 'b) t -> ('a, 'b, Bigarray.c_layout) Bigarray.Array1.t

Return the raw buffer of t.

Sourceval op_buffer : context -> ('a, 'b) Nx_core.Dtype.t -> int -> ('a, 'b) t

Allocate a buffer of size_in_elements elements of dtype.

Sourceval op_const_scalar : context -> 'a -> ('a, 'b) Nx_core.Dtype.t -> ('a, 'b) t

Tensor containing a single scalar value.

Sourceval op_const_array : context -> ('a, 'b, Bigarray.c_layout) Bigarray.Array1.t -> ('a, 'b) t

Tensor containing the elements of array. The array must be contiguous.

Sourceval op_add : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

Element-wise addition.

Sourceval op_mul : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

Element-wise multiplication.

Sourceval op_idiv : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

Integer division, truncating.

Sourceval op_fdiv : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

Floating-point division.

Sourceval op_max : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

Element-wise maximum.

Sourceval op_mod : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

Integer modulus.

Sourceval op_pow : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

Raise base to exponent.

Sourceval op_cmplt : ('a, 'b) t -> ('a, 'b) t -> (int, Nx_core.Dtype.uint8_elt) t

Compare <. Returns 0 or 1 as uint8.

Sourceval op_cmpne : ('a, 'b) t -> ('a, 'b) t -> (int, Nx_core.Dtype.uint8_elt) t

Compare <>. Returns 0 or 1 as uint8.

Sourceval op_xor : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

Bitwise XOR.

Sourceval op_or : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

Bitwise OR.

Sourceval op_and : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

Bitwise AND.

Sourceval op_neg : ('a, 'b) t -> ('a, 'b) t

Negation (logical not for bools).

Sourceval op_log2 : ('a, 'b) t -> ('a, 'b) t

Base-2 logarithm.

Sourceval op_exp2 : ('a, 'b) t -> ('a, 'b) t

Exponential base 2.

Sourceval op_sin : ('a, 'b) t -> ('a, 'b) t

Sine.

Sourceval op_sqrt : ('a, 'b) t -> ('a, 'b) t

Square root.

Sourceval op_recip : ('a, 'b) t -> ('a, 'b) t

Reciprocal.

Sourceval op_where : (int, Nx_core.Dtype.uint8_elt) t -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

Select from if_true or if_false based on a boolean tensor.

Sourceval op_reduce_sum : axes:int array -> keepdims:bool -> ('a, 'b) t -> ('a, 'b) t

Sum over axes. Keeps reduced dimensions if keepdims is true.

Sourceval op_reduce_max : axes:int array -> keepdims:bool -> ('a, 'b) t -> ('a, 'b) t

Maximum over axes. Keeps reduced dimensions if keepdims is true.

Sourceval op_reduce_prod : axes:int array -> keepdims:bool -> ('a, 'b) t -> ('a, 'b) t

Product over axes. Keeps reduced dimensions if keepdims is true.

Sourceval op_associative_scan : axis:int -> op:[ `Sum | `Prod | `Max | `Min ] -> ('a, 'b) t -> ('a, 'b) t

Inclusive scan along axis using the associative operation op.

Sourceval op_expand : ('a, 'b) t -> Nx_core.Symbolic_shape.t -> ('a, 'b) t

Broadcast dimensions of size 1 to a new shape.

Sourceval op_reshape : ('a, 'b) t -> Nx_core.Symbolic_shape.t -> ('a, 'b) t

Change the logical shape without moving data.

Sourceval op_permute : ('a, 'b) t -> int array -> ('a, 'b) t

Reorder dimensions according to axes.

Sourceval op_shrink : ('a, 'b) t -> (int * int) array -> ('a, 'b) t

Slice according to the given start/stop pairs.

Sourceval op_flip : ('a, 'b) t -> bool array -> ('a, 'b) t

Flip dimensions where the boolean array is true.

Sourceval op_pad : ('a, 'b) t -> (int * int) array -> 'a -> ('a, 'b) t

Pad with fill_value using the given configuration.

Sourceval op_cat : ('a, 'b) t list -> int -> ('a, 'b) t

Concatenate tensors along axis.

Sourceval op_cast : ('a, 'b) t -> ('c, 'd) Nx_core.Dtype.t -> ('c, 'd) t

Cast elements to target_dtype.

Sourceval op_contiguous : ('a, 'b) t -> ('a, 'b) t

Return a C-contiguous tensor. May copy.

Sourceval op_copy : ('a, 'b) t -> ('a, 'b) t

Duplicate t. Result has its own buffer.

Sourceval op_assign : ('a, 'b) t -> ('a, 'b) t -> unit

Store src into dst at the given logical indices.

Sourceval op_threefry : (int32, Nx_core.Dtype.int32_elt) t -> (int32, Nx_core.Dtype.int32_elt) t -> (int32, Nx_core.Dtype.int32_elt) t

Threefry random number generator.

Sourceval op_gather : ('a, 'b) t -> (int32, Nx_core.Dtype.int32_elt) t -> int -> ('a, 'b) t

Gather elements from data along axis using indices. Output shape matches indices. Ranks of data and indices must match. Sizes of indices dims != axis must be <= data corresponding dims.

Sourceval op_scatter : ?mode:[ `Set | `Add ] -> ?unique_indices:bool -> ('a, 'b) t -> (int32, Nx_core.Dtype.int32_elt) t -> ('a, 'b) t -> int -> ('a, 'b) t

Scatter updates into a new tensor shaped like data_template along axis using indices. Returns a new tensor.

  • mode specifies how to handle duplicate indices:
  • `Set (default): last update wins
  • `Add: accumulate updates at duplicate indices
  • unique_indices: hint that indices are unique (optimization)
Sourceval op_unfold : ('a, 'b) t -> kernel_size:int array -> stride:int array -> dilation:int array -> padding:(int * int) array -> ('a, 'b) t

Unfold (im2col) operation. Extracts sliding local blocks from a batched input tensor. For an input of shape (N, C, *spatial_dims), produces output of shape (N, C * prod(kernel_size), L) where L is the number of blocks. Works for any number of spatial dimensions.

Sourceval op_fold : ('a, 'b) t -> output_size:int array -> kernel_size:int array -> stride:int array -> dilation:int array -> padding:(int * int) array -> ('a, 'b) t

Fold (col2im) operation. Combines an array of sliding local blocks into a tensor. For an input of shape (N, C * prod(kernel_size), L), produces output of shape (N, C, *output_size). Inverse of unfold. Overlapping values are summed. Works for any number of spatial dimensions.

Sourceval op_matmul : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

Matrix multiplication. For 2D tensors, computes standard matrix multiplication. For higher dimensions, performs batched matrix multiplication on the last two dimensions, broadcasting batch dimensions as needed. The last dimension of the first tensor must match the second-to-last dimension of the second tensor.

Sourceval op_fft : (Complex.t, 'b) t -> axes:int array -> (Complex.t, 'b) t

Compute the discrete Fourier transform (DFT) of the input tensor.

Sourceval op_ifft : (Complex.t, 'b) t -> axes:int array -> (Complex.t, 'b) t

Compute the inverse discrete Fourier transform (IDFT) of the input tensor.

Sourceval op_rfft : (float, 'a) t -> dtype:(Complex.t, 'b) Nx_core.Dtype.t -> axes:int array -> (Complex.t, 'b) t

Compute the real-valued discrete Fourier transform (RDFT) of the input tensor.

Sourceval op_irfft : (Complex.t, 'a) t -> dtype:(float, 'b) Nx_core.Dtype.t -> axes:int array -> s:int array option -> (float, 'b) t

Compute the inverse real-valued discrete Fourier transform (IRDFT) of the input tensor.

Sourceval op_cholesky : upper:bool -> ('a, 'b) t -> ('a, 'b) t

Cholesky decomposition of a positive-definite matrix.

  • upper: If true, returns upper triangular factor; else lower (default).
  • Input: Square matrix A (batched).
  • Output: Triangular factor L or U such that A = L*L^T or A = U^T*U.
  • Raises if input is not positive-definite.
Sourceval op_qr : reduced:bool -> ('a, 'b) t -> ('a, 'b) t * ('a, 'b) t

QR decomposition.

  • reduced: If true (default), returns economy/reduced QR; else full QR.
  • Input: m x n matrix A (batched).
  • Output: (Q, R) where A = Q*R, Q orthogonal, R upper triangular.
Sourceval op_svd : full_matrices:bool -> ('a, 'b) t -> ('a, 'b) t * (float, Nx_core.Dtype.float64_elt) t * ('a, 'b) t

Singular value decomposition.

  • full_matrices: If false (default), returns thin SVD; else full.
  • Input: m x n matrix A (batched).
  • Output: (U, S, V^H) where A = U*S*V^H.
  • S is 1D vector of singular values in descending order, always float64.
Sourceval op_eig : vectors:bool -> ('a, 'b) t -> (Complex.t, Nx_core.Dtype.complex64_elt) t * (Complex.t, Nx_core.Dtype.complex64_elt) t option

General eigenvalue decomposition.

  • vectors: If true (default), computes eigenvectors.
  • Input: Square matrix A (batched).
  • Output: (eigenvalues, optional eigenvectors) always as complex64.
Sourceval op_eigh : vectors:bool -> ('a, 'b) t -> (float, Nx_core.Dtype.float64_elt) t * ('a, 'b) t option

Symmetric/Hermitian eigenvalue decomposition.

  • vectors: If true (default), computes eigenvectors.
  • Input: Symmetric (real) or Hermitian (complex) matrix A (batched).
  • Output: (eigenvalues as float64, eigenvectors same type as input).
Sourceval op_triangular_solve : upper:bool -> transpose:bool -> unit_diag:bool -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

Solve triangular system A*x = b or A^T*x = b.

  • upper: If true, A is upper triangular; else lower.
  • transpose: If true, solve A^T*x = b; else A*x = b.
  • unit_diag: If true, assume diagonal of A is all 1s.
  • Input: Triangular matrix A, right-hand side b (batched).
  • Output: Solution x.
Sourceval op_as_strided : ('a, 'b) t -> Nx_core.Symbolic_shape.t -> int array -> int -> ('a, 'b) t

Create a strided view of the input tensor with the given shape, strides (in elements), and offset (in elements). Backends that support arbitrary strided views (e.g., native with Bigarray) can implement this as zero-copy. Other backends may fall back to copying data if necessary. Raises if the view would access out-of-bounds memory.

Sourceval create_context : unit -> context