package nx

  1. Overview
  2. Docs

doc/nx/Nx/index.html

Module NxSource

N-dimensional arrays.

Nx provides n-dimensional arrays (tensors) with NumPy-like semantics. A tensor ('a, 'b) t holds elements of OCaml type 'a stored in a buffer with element kind 'b.

Tensors, views, and contiguity. A tensor is a view over a flat buffer described by a shape, strides, and an offset. Operations that only rearrange metadata (reshape, transpose, slice, …) return views in O(1) without copying data. Use is_c_contiguous to test whether elements are laid out contiguously in row-major order, and contiguous to obtain a contiguous copy when needed.

Broadcasting. Binary operations automatically broadcast operands whose shapes differ: dimensions are aligned from the right and each pair must be equal or one of them must be 1.

The ?out convention. Many operations accept an optional ?out tensor. When provided, the result is written into out instead of allocating a fresh tensor; the shape of out must match the result shape.

Types

Sourcetype ('a, 'b) t = ('a, 'b) Nx_effect.t

The type for tensors with OCaml element type 'a and buffer element kind 'b.

Element kinds

Witnesses for the buffer element representation. Used as the second type parameter of t.

Sourcetype float16_elt = Nx_buffer.float16_elt
Sourcetype float32_elt = Nx_buffer.float32_elt
Sourcetype float64_elt = Nx_buffer.float64_elt
Sourcetype bfloat16_elt = Nx_buffer.bfloat16_elt
Sourcetype float8_e4m3_elt = Nx_buffer.float8_e4m3_elt
Sourcetype float8_e5m2_elt = Nx_buffer.float8_e5m2_elt
Sourcetype int32_elt = Nx_buffer.int32_elt
Sourcetype uint32_elt = Nx_buffer.uint32_elt
Sourcetype int64_elt = Nx_buffer.int64_elt
Sourcetype uint64_elt = Nx_buffer.uint64_elt
Sourcetype complex32_elt = Nx_buffer.complex32_elt
Sourcetype complex64_elt = Nx_buffer.complex64_elt
Sourcetype bool_elt = Nx_buffer.bool_elt

Data types

Sourcetype ('a, 'b) dtype = ('a, 'b) Nx_core.Dtype.t =
  1. | Float16 : (float, float16_elt) dtype
  2. | Float32 : (float, float32_elt) dtype
  3. | Float64 : (float, float64_elt) dtype
  4. | BFloat16 : (float, bfloat16_elt) dtype
  5. | Float8_e4m3 : (float, float8_e4m3_elt) dtype
  6. | Float8_e5m2 : (float, float8_e5m2_elt) dtype
  7. | Int4 : (int, int4_elt) dtype
  8. | UInt4 : (int, uint4_elt) dtype
  9. | Int8 : (int, int8_elt) dtype
  10. | UInt8 : (int, uint8_elt) dtype
  11. | Int16 : (int, int16_elt) dtype
  12. | UInt16 : (int, uint16_elt) dtype
  13. | Int32 : (int32, int32_elt) dtype
  14. | UInt32 : (int32, uint32_elt) dtype
  15. | Int64 : (int64, int64_elt) dtype
  16. | UInt64 : (int64, uint64_elt) dtype
  17. | Complex64 : (Complex.t, complex32_elt) dtype
  18. | Complex128 : (Complex.t, complex64_elt) dtype
  19. | Bool : (bool, bool_elt) dtype
    (*

    The type for data type descriptors. A ('a, 'b) dtype links the OCaml element type 'a to its buffer representation 'b.

    *)

Tensor aliases

Sourcetype float16_t = (float, float16_elt) t
Sourcetype float32_t = (float, float32_elt) t
Sourcetype float64_t = (float, float64_elt) t
Sourcetype bfloat16_t = (float, bfloat16_elt) t
Sourcetype float8_e4m3_t = (float, float8_e4m3_elt) t
Sourcetype float8_e5m2_t = (float, float8_e5m2_elt) t
Sourcetype int4_t = (int, int4_elt) t
Sourcetype uint4_t = (int, uint4_elt) t
Sourcetype int8_t = (int, int8_elt) t
Sourcetype uint8_t = (int, uint8_elt) t
Sourcetype int16_t = (int, int16_elt) t
Sourcetype uint16_t = (int, uint16_elt) t
Sourcetype int32_t = (int32, int32_elt) t
Sourcetype uint32_t = (int32, uint32_elt) t
Sourcetype int64_t = (int64, int64_elt) t
Sourcetype uint64_t = (int64, uint64_elt) t
Sourcetype complex64_t = (Complex.t, complex32_elt) t
Sourcetype complex128_t = (Complex.t, complex64_elt) t
Sourcetype bool_t = (bool, bool_elt) t

Data type values

Sourceval float16 : (float, float16_elt) dtype
Sourceval float32 : (float, float32_elt) dtype
Sourceval float64 : (float, float64_elt) dtype
Sourceval bfloat16 : (float, bfloat16_elt) dtype
Sourceval float8_e4m3 : (float, float8_e4m3_elt) dtype
Sourceval float8_e5m2 : (float, float8_e5m2_elt) dtype
Sourceval int4 : (int, int4_elt) dtype
Sourceval uint4 : (int, uint4_elt) dtype
Sourceval int8 : (int, int8_elt) dtype
Sourceval uint8 : (int, uint8_elt) dtype
Sourceval int16 : (int, int16_elt) dtype
Sourceval uint16 : (int, uint16_elt) dtype
Sourceval int32 : (int32, int32_elt) dtype
Sourceval uint32 : (int32, uint32_elt) dtype
Sourceval int64 : (int64, int64_elt) dtype
Sourceval uint64 : (int64, uint64_elt) dtype
Sourceval complex128 : (Complex.t, complex64_elt) dtype
Sourceval bool : (bool, bool_elt) dtype

Index specifications

Sourcetype index =
  1. | I of int
    (*

    I i selects a single index, reducing the dimension.

    *)
  2. | L of int list
    (*

    L [i0; i1; …] gathers the listed indices.

    *)
  3. | R of int * int
    (*

    R (start, stop) selects the half-open range [start, stop).

    *)
  4. | Rs of int * int * int
    (*

    Rs (start, stop, step) selects a strided range.

    *)
  5. | A
    (*

    A selects the entire axis. This is the default for axes not covered by a slice specification.

    *)
  6. | M of (bool, bool_elt) t
    (*

    M mask selects positions where mask is true.

    *)
  7. | N
    (*

    N inserts a new axis of size 1 (does not consume an input axis).

    *)

The type for index specifications used by slice and set_slice.

Properties

Sourceval data : ('a, 'b) t -> ('a, 'b) Nx_buffer.t

data t is the underlying flat buffer of t.

The buffer is shared: mutations through the buffer are visible through t and vice-versa. The buffer may be larger than the tensor's logical extent when t is a strided view.

Sourceval shape : ('a, 'b) t -> int array

shape t is the dimensions of t. A scalar tensor has shape |\||.

Sourceval dtype : ('a, 'b) t -> ('a, 'b) dtype

dtype t is the data type of t.

Sourceval strides : ('a, 'b) t -> int array

strides t is the byte stride for each dimension of t.

Raises Invalid_argument if t does not have computable strides (e.g. after certain non-contiguous view operations). Use is_c_contiguous or call contiguous first.

See also stride.

Sourceval stride : int -> ('a, 'b) t -> int

stride i t is the byte stride of dimension i.

Raises Invalid_argument if i is out of bounds or t does not have computable strides.

See also strides.

Sourceval dims : ('a, 'b) t -> int array
Sourceval dim : int -> ('a, 'b) t -> int

dim i t is the size of dimension i.

Raises Invalid_argument if i is out of bounds.

Sourceval ndim : ('a, 'b) t -> int

ndim t is the number of dimensions of t.

Sourceval itemsize : ('a, 'b) t -> int

itemsize t is the number of bytes per element.

Sourceval size : ('a, 'b) t -> int

size t is the total number of elements.

Sourceval numel : ('a, 'b) t -> int

numel t is size.

Sourceval nbytes : ('a, 'b) t -> int

nbytes t is size t * itemsize t.

Sourceval offset : ('a, 'b) t -> int

offset t is the element offset of t in its underlying buffer.

Sourceval is_c_contiguous : ('a, 'b) t -> bool

is_c_contiguous t is true iff t's elements are laid out contiguously in row-major (C) order.

See also contiguous.

Sourceval to_bigarray : ('a, 'b) t -> ('a, 'b, Bigarray.c_layout) Bigarray.Genarray.t

to_bigarray t is a contiguous bigarray with the same shape and data as t. Always copies.

Raises Invalid_argument if t's dtype is an extended type not supported by Bigarray.

See also of_bigarray.

Sourceval to_buffer : ('a, 'b) t -> ('a, 'b) Nx_buffer.t

to_buffer t is a flat, contiguous buffer of t's data.

Returns the underlying buffer directly when t is already contiguous with zero offset and matching size; copies otherwise.

Sourceval to_array : ('a, 'b) t -> 'a array

to_array t is a fresh OCaml array containing the elements of t in row-major order. Always copies.

  # let t =
      create int32 [| 2; 2 |] [| 1l; 2l; 3l; 4l |]
    in
    to_array t
  - : int32 array = [|1l; 2l; 3l; 4l|]

Creation

Sourceval create : ('a, 'b) dtype -> int array -> 'a array -> ('a, 'b) t

create dtype shape data is a tensor of the given dtype and shape initialised from data in row-major order.

Raises Invalid_argument if Array.length data does not equal the product of shape.

  # create float32 [| 2; 3 |]
      [| 1.; 2.; 3.; 4.; 5.; 6. |]
  - : (float, float32_elt) t = float32 [2; 3] [[1, 2, 3],
                                               [4, 5, 6]]
Sourceval init : ('a, 'b) dtype -> int array -> (int array -> 'a) -> ('a, 'b) t

init dtype shape f is a tensor where the element at multi-index i is f i.

  # init int32 [| 2; 3 |]
      (fun i -> Int32.of_int (i.(0) + i.(1)))
  - : (int32, int32_elt) t = int32 [2; 3] [[0, 1, 2],
                                           [1, 2, 3]]
Sourceval empty : ('a, 'b) dtype -> int array -> ('a, 'b) t

empty dtype shape is an uninitialized tensor.

Warning. Elements contain arbitrary values until written.

Sourceval full : ('a, 'b) dtype -> int array -> 'a -> ('a, 'b) t

full dtype shape v is a tensor filled with v.

  # full float32 [| 2; 3 |] 3.14
  - : (float, float32_elt) t = float32 [2; 3]
  [[3.14, 3.14, 3.14],
   [3.14, 3.14, 3.14]]
Sourceval ones : ('a, 'b) dtype -> int array -> ('a, 'b) t

ones dtype shape is a tensor filled with ones.

Sourceval zeros : ('a, 'b) dtype -> int array -> ('a, 'b) t

zeros dtype shape is a tensor filled with zeros.

Sourceval scalar : ('a, 'b) dtype -> 'a -> ('a, 'b) t

scalar dtype v is a 0-dimensional tensor containing v. The result has shape |\||.

Sourceval empty_like : ('a, 'b) t -> ('a, 'b) t

empty_like t is empty with the same dtype and shape as t.

Sourceval full_like : ('a, 'b) t -> 'a -> ('a, 'b) t

full_like t v is full with the same dtype and shape as t.

Sourceval ones_like : ('a, 'b) t -> ('a, 'b) t

ones_like t is ones with the same dtype and shape as t.

Sourceval zeros_like : ('a, 'b) t -> ('a, 'b) t

zeros_like t is zeros with the same dtype and shape as t.

Sourceval scalar_like : ('a, 'b) t -> 'a -> ('a, 'b) t

scalar_like t v is scalar with the same dtype as t.

Sourceval eye : ?m:int -> ?k:int -> ('a, 'b) dtype -> int -> ('a, 'b) t

eye ?m ?k dtype n is an n × m matrix with ones on the k-th diagonal and zeros elsewhere. m defaults to n. k defaults to 0 (main diagonal); positive k selects an upper diagonal, negative k a lower one.

  # eye int32 3
  - : (int32, int32_elt) t = int32 [3; 3] [[1, 0, 0],
                                           [0, 1, 0],
                                           [0, 0, 1]]
  # eye ~k:1 int32 3
  - : (int32, int32_elt) t = int32 [3; 3] [[0, 1, 0],
                                           [0, 0, 1],
                                           [0, 0, 0]]

See also identity, diag.

Sourceval identity : ('a, 'b) dtype -> int -> ('a, 'b) t

identity dtype n is eye dtype n.

Sourceval diag : ?k:int -> ('a, 'b) t -> ('a, 'b) t

diag ?k v extracts or constructs a diagonal.

When v is 1-D, returns a 2-D tensor with v on the k-th diagonal. When v is 2-D, returns the k-th diagonal as a 1-D tensor. k defaults to 0.

Raises Invalid_argument if v is not 1-D or 2-D.

  # let v = create int32 [| 3 |] [| 1l; 2l; 3l |] in
    diag v
  - : (int32, int32_elt) t = int32 [3; 3] [[1, 0, 0],
                                           [0, 2, 0],
                                           [0, 0, 3]]
  # let x =
      arange int32 0 9 1 |> reshape [| 3; 3 |]
    in
    diag x
  - : (int32, int32_elt) t = [0, 4, 8]

See also eye, diagonal.

Sourceval arange : ('a, 'b) dtype -> int -> int -> int -> ('a, 'b) t

arange dtype start stop step is a 1-D tensor of values from start (inclusive) to stop (exclusive) with stride step.

Raises Invalid_argument if step = 0.

  # arange int32 0 10 2
  - : (int32, int32_elt) t = int32 [5] [0, 2, ..., 6, 8]
  # arange int32 5 0 (-1)
  - : (int32, int32_elt) t = int32 [5] [5, 4, ..., 2, 1]

See also arange_f, linspace.

Sourceval arange_f : (float, 'a) dtype -> float -> float -> float -> (float, 'a) t

arange_f dtype start stop step is like arange for floating-point ranges.

Raises Invalid_argument if step = 0.0.

  # arange_f float32 0. 1. 0.2
  - : (float, float32_elt) t = float32 [5] [0, 0.2, ..., 0.6, 0.8]

See also arange, linspace.

Sourceval linspace : ('a, 'b) dtype -> ?endpoint:bool -> float -> float -> int -> ('a, 'b) t

linspace dtype ?endpoint start stop n is n values evenly spaced from start to stop. endpoint defaults to true (include stop).

Raises Invalid_argument if n is negative.

  # linspace float32 0. 10. 5
  - : (float, float32_elt) t = float32 [5] [0, 2.5, ..., 7.5, 10]
  # linspace float32 ~endpoint:false 0. 10. 5
  - : (float, float32_elt) t = float32 [5] [0, 2, ..., 6, 8]

See also logspace, geomspace.

Sourceval logspace : (float, 'a) dtype -> ?endpoint:bool -> ?base:float -> float -> float -> int -> (float, 'a) t

logspace dtype ?endpoint ?base start stop n is n values evenly spaced on a logarithmic scale: base{^x} where x ranges from start to stop. endpoint defaults to true. base defaults to 10.0.

Raises Invalid_argument if n is negative.

  # logspace float32 0. 2. 3
  - : (float, float32_elt) t = [1, 10, 100]
  # logspace float32 ~base:2.0 0. 3. 4
  - : (float, float32_elt) t = [1, 2, 4, 8]

See also linspace, geomspace.

Sourceval geomspace : (float, 'a) dtype -> ?endpoint:bool -> float -> float -> int -> (float, 'a) t

geomspace dtype ?endpoint start stop n is n values evenly spaced on a geometric (multiplicative) scale. endpoint defaults to true.

Raises Invalid_argument if start or stop is not positive.

  # geomspace float32 1. 1000. 4
  - : (float, float32_elt) t = [1, 10, 100, 1000]

See also linspace, logspace.

Sourceval meshgrid : ?indexing:[ `xy | `ij ] -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t * ('a, 'b) t

meshgrid ?indexing x y is a pair of 2-D coordinate grids built from 1-D arrays x and y. indexing defaults to `xy (Cartesian: X varies along columns, Y along rows). With `ij (matrix), X varies along rows, Y along columns.

Raises Invalid_argument if x or y is not 1-D.

  # let x = linspace float32 0. 2. 3 in
    let y = linspace float32 0. 1. 2 in
    meshgrid x y
  - : (float, float32_elt) t * (float, float32_elt) t =
  (float32 [2; 3] [[0, 1, 2],
                   [0, 1, 2]], float32 [2; 3] [[0, 0, 0],
                                               [1, 1, 1]])
Sourceval tril : ?k:int -> ('a, 'b) t -> ('a, 'b) t

tril ?k x is the lower-triangular part of x with elements above the k-th diagonal set to zero. k defaults to 0.

Raises Invalid_argument if x has fewer than 2 dimensions.

See also triu.

Sourceval triu : ?k:int -> ('a, 'b) t -> ('a, 'b) t

triu ?k x is the upper-triangular part of x with elements below the k-th diagonal set to zero. k defaults to 0.

Raises Invalid_argument if x has fewer than 2 dimensions.

See also tril.

Sourceval of_bigarray : ('a, 'b, Bigarray.c_layout) Bigarray.Genarray.t -> ('a, 'b) t

of_bigarray ba is a tensor sharing memory with ba.

Zero-copy: mutations through either are visible to both.

See also to_bigarray.

Sourceval of_buffer : ('a, 'b) Nx_buffer.t -> shape:int array -> ('a, 'b) t

of_buffer buf ~shape is a tensor viewing buf with the given shape. The product of shape must equal the buffer length.

Sourceval one_hot : num_classes:int -> ('a, 'b) t -> (int, uint8_elt) t

one_hot ~num_classes indices is a one-hot encoded tensor.

Appends a new trailing dimension of size num_classes. Values in indices must lie in [0, num_classes). Out-of-range indices produce all-zero rows.

Raises Invalid_argument if indices is not an integer dtype or num_classes <= 0.

  # let idx =
      create int32 [| 3 |] [| 0l; 1l; 3l |]
    in
    one_hot ~num_classes:4 idx
  - : (int, uint8_elt) t = uint8 [3; 4]
  [[1, 0, 0, 0],
   [0, 1, 0, 0],
   [0, 0, 0, 1]]

Random number generation

Sampling functions use the implicit RNG state managed by Rng. Wrap calls in Rng.run for reproducibility:

  Rng.run ~seed:42 (fun () -> rand float32 [| 3 |])
Sourcemodule Rng : sig ... end

Splittable RNG keys and implicit key management.

Sourceval rand : ('a, 'b) dtype -> int array -> ('a, 'b) t

rand dtype shape samples uniformly from [0, 1).

Raises Invalid_argument if dtype is not a float type.

Sourceval randn : ('a, 'b) dtype -> int array -> ('a, 'b) t

randn dtype shape samples from the standard normal distribution (mean 0, variance 1) via the Box–Muller transform.

Raises Invalid_argument if dtype is not a float type.

Sourceval randint : ('a, 'b) dtype -> ?high:int -> int array -> int -> ('a, 'b) t

randint dtype ?high shape low samples integers uniformly from [low, high). high defaults to 10.

Raises Invalid_argument if dtype is not an integer type or low >= high.

Sourceval bernoulli : p:float -> int array -> bool_t

bernoulli ~p shape samples booleans that are true with probability p.

Raises Invalid_argument if p is not in [0, 1].

Sourceval permutation : int -> int32_t

permutation n is a random permutation of [0, n-1].

Raises Invalid_argument if n <= 0.

Sourceval shuffle : ('a, 'b) t -> ('a, 'b) t

shuffle t is a copy of t with the first axis randomly permuted. No-op on scalars.

Sourceval categorical : ?axis:int -> ?shape:int array -> (float, 'a) t -> int32_t

categorical ?axis ?shape logits samples category indices from unnormalised log-probabilities using the Gumbel-max trick. axis defaults to -1 (last axis). shape prepends extra batch dimensions.

Raises Invalid_argument if logits is not a float type or axis is out of bounds.

Sourceval truncated_normal : ('a, 'b) dtype -> lower:float -> upper:float -> int array -> ('a, 'b) t

truncated_normal dtype ~lower ~upper shape samples from a standard normal distribution truncated to [lower, upper].

Raises Invalid_argument if dtype is not a float type or lower >= upper.

Shape manipulation

Sourceval reshape : int array -> ('a, 'b) t -> ('a, 'b) t

reshape shape t is a view of t with the given shape.

At most one dimension may be -1; it is inferred from the total number of elements. The product of shape must equal size t.

Raises Invalid_argument if shape is incompatible or contains more than one -1.

  # create int32 [| 6 |] [| 1l; 2l; 3l; 4l; 5l; 6l |]
    |> reshape [| 2; 3 |]
  - : (int32, int32_elt) t = int32 [2; 3] [[1, 2, 3],
                                           [4, 5, 6]]
  # create int32 [| 6 |] [| 1l; 2l; 3l; 4l; 5l; 6l |]
    |> reshape [| 3; -1 |]
  - : (int32, int32_elt) t = int32 [3; 2] [[1, 2],
                                           [3, 4],
                                           [5, 6]]

See also flatten, unflatten, ravel.

Sourceval broadcast_to : int array -> ('a, 'b) t -> ('a, 'b) t

broadcast_to shape t is a view of t broadcast to shape.

Dimensions are aligned from the right; each dimension of t must be 1 or equal to the corresponding target dimension. Broadcast dimensions have zero byte-stride (no copy).

Raises Invalid_argument if the shapes are incompatible.

  # create int32 [| 1; 3 |] [| 1l; 2l; 3l |]
    |> broadcast_to [| 3; 3 |]
  - : (int32, int32_elt) t = int32 [3; 3] [[1, 2, 3],
                                           [1, 2, 3],
                                           [1, 2, 3]]

See also broadcasted, expand.

Sourceval broadcasted : ?reverse:bool -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t * ('a, 'b) t

broadcasted ?reverse t1 t2 is (t1', t2') where both are broadcast to their common shape. When reverse is true (default false), returns (t2', t1').

Raises Invalid_argument if the shapes are incompatible.

See also broadcast_to, broadcast_arrays.

Sourceval expand : int array -> ('a, 'b) t -> ('a, 'b) t

expand shape t is like broadcast_to but -1 in shape preserves the corresponding dimension of t.

Raises Invalid_argument if any dimension in shape is negative (other than -1).

  # ones float32 [| 1; 4; 1 |]
    |> expand [| 3; -1; 5 |] |> shape
  - : int array = [|3; 4; 5|]

See also broadcast_to.

Sourceval flatten : ?start_dim:int -> ?end_dim:int -> ('a, 'b) t -> ('a, 'b) t

flatten ?start_dim ?end_dim t collapses dimensions start_dim through end_dim (inclusive) into a single dimension. start_dim defaults to 0. end_dim defaults to -1 (last). Negative indices count from the end.

Raises Invalid_argument if indices are out of bounds.

  # zeros float32 [| 2; 3; 4 |] |> flatten |> shape
  - : int array = [|24|]
  # zeros float32 [| 2; 3; 4; 5 |]
    |> flatten ~start_dim:1 ~end_dim:2 |> shape
  - : int array = [|2; 12; 5|]

See also unflatten, ravel.

Sourceval unflatten : int -> int array -> ('a, 'b) t -> ('a, 'b) t

unflatten dim sizes t expands dimension dim into multiple dimensions given by sizes. At most one element of sizes may be -1 (inferred). The product of sizes must equal the size of dimension dim.

Raises Invalid_argument if the product mismatches or dim is out of bounds.

  # zeros float32 [| 2; 12; 5 |]
    |> unflatten 1 [| 3; 4 |] |> shape
  - : int array = [|2; 3; 4; 5|]

See also flatten.

Sourceval ravel : ('a, 'b) t -> ('a, 'b) t

ravel t is t reshaped to 1-D. Returns a view when possible.

Raises Invalid_argument if t cannot be flattened without copying; call contiguous first.

See also flatten, contiguous.

Sourceval squeeze : ?axes:int list -> ('a, 'b) t -> ('a, 'b) t

squeeze ?axes t removes dimensions of size 1. When axes is given, only those axes are removed. Negative indices count from the end.

Raises Invalid_argument if a specified axis does not have size 1.

  # ones float32 [| 1; 3; 1; 4 |]
    |> squeeze |> shape
  - : int array = [|3; 4|]
  # ones float32 [| 1; 3; 1; 4 |]
    |> squeeze ~axes:[ 0 ] |> shape
  - : int array = [|3; 1; 4|]

See also unsqueeze.

Sourceval unsqueeze : ?axes:int list -> ('a, 'b) t -> ('a, 'b) t

unsqueeze ?axes t inserts dimensions of size 1 at the positions listed in axes. Positions refer to the result tensor.

Raises Invalid_argument if axes is not specified, contains duplicates, or values are out of bounds.

  # create float32 [| 3 |] [| 1.; 2.; 3. |]
    |> unsqueeze ~axes:[ 0; 2 ] |> shape
  - : int array = [|1; 3; 1|]

See also squeeze, expand_dims.

Sourceval squeeze_axis : int -> ('a, 'b) t -> ('a, 'b) t

squeeze_axis i t removes dimension i if its size is 1.

Raises Invalid_argument if dimension i is not 1.

See also squeeze.

Sourceval unsqueeze_axis : int -> ('a, 'b) t -> ('a, 'b) t

unsqueeze_axis i t inserts a dimension of size 1 at position i.

See also unsqueeze.

Sourceval expand_dims : int list -> ('a, 'b) t -> ('a, 'b) t

expand_dims axes t is unsqueeze ~axes t.

Sourceval transpose : ?axes:int list -> ('a, 'b) t -> ('a, 'b) t

transpose ?axes t permutes the dimensions of t.

axes must be a permutation of [0; …; ndim t - 1]. When omitted, reverses all dimensions. Returns a view (no copy).

Raises Invalid_argument if axes is not a valid permutation.

  # create int32 [| 2; 3 |] [| 1l; 2l; 3l; 4l; 5l; 6l |]
    |> transpose
  - : (int32, int32_elt) t = int32 [3; 2] [[1, 4],
                                           [2, 5],
                                           [3, 6]]

See also matrix_transpose, moveaxis, swapaxes.

Sourceval flip : ?axes:int list -> ('a, 'b) t -> ('a, 'b) t

flip ?axes t reverses elements along the given axes. When omitted, flips all dimensions.

Raises Invalid_argument if any axis is out of bounds.

  # create int32 [| 2; 3 |] [| 1l; 2l; 3l; 4l; 5l; 6l |]
    |> flip ~axes:[ 1 ]
  - : (int32, int32_elt) t = int32 [2; 3] [[3, 2, 1],
                                           [6, 5, 4]]
Sourceval moveaxis : int -> int -> ('a, 'b) t -> ('a, 'b) t

moveaxis src dst t moves dimension src to position dst.

Raises Invalid_argument if either index is out of bounds.

See also transpose, swapaxes.

Sourceval swapaxes : int -> int -> ('a, 'b) t -> ('a, 'b) t

swapaxes a1 a2 t exchanges dimensions a1 and a2.

Raises Invalid_argument if either index is out of bounds.

See also transpose, moveaxis.

Sourceval roll : ?axis:int -> int -> ('a, 'b) t -> ('a, 'b) t

roll ?axis shift t shifts elements along axis by shift positions, wrapping around. When axis is omitted, operates on the flattened tensor. Negative shift rolls backward.

Raises Invalid_argument if axis is out of bounds.

  # create int32 [| 5 |] [| 1l; 2l; 3l; 4l; 5l |]
    |> roll 2
  - : (int32, int32_elt) t = int32 [5] [4, 5, ..., 2, 3]
Sourceval pad : (int * int) array -> 'a -> ('a, 'b) t -> ('a, 'b) t

pad widths value t pads t with value. widths.(i) is (before, after) for dimension i.

Raises Invalid_argument if Array.length widths does not match ndim t or any width is negative.

  # create float32 [| 2; 2 |] [| 1.; 2.; 3.; 4. |]
    |> pad [| (1, 1); (1, 1) |] 0. |> shape
  - : int array = [|4; 4|]

See also shrink.

Sourceval shrink : (int * int) array -> ('a, 'b) t -> ('a, 'b) t

shrink ranges t extracts a slice where ranges.(i) is (start, stop) (exclusive) for dimension i. Returns a view.

  # create int32 [| 3; 3 |]
      [| 1l; 2l; 3l; 4l; 5l; 6l; 7l; 8l; 9l |]
    |> shrink [| (1, 3); (0, 2) |]
  - : (int32, int32_elt) t = int32 [2; 2] [[4, 5],
                                           [7, 8]]

See also pad.

Sourceval tile : int array -> ('a, 'b) t -> ('a, 'b) t

tile reps t is t repeated according to reps. reps.(i) gives the repetition count along dimension i. If reps is longer than ndim t, dimensions are prepended.

Raises Invalid_argument if any repetition count is negative.

  # create int32 [| 1; 2 |] [| 1l; 2l |]
    |> tile [| 2; 3 |]
  - : (int32, int32_elt) t = int32 [2; 6] [[1, 2, ..., 1, 2],
                                           [1, 2, ..., 1, 2]]

See also repeat.

Sourceval repeat : ?axis:int -> int -> ('a, 'b) t -> ('a, 'b) t

repeat ?axis n t repeats each element n times along axis. When axis is omitted, operates on the flattened tensor.

Raises Invalid_argument if n is negative or axis is out of bounds.

  # create int32 [| 3 |] [| 1l; 2l; 3l |]
    |> repeat 2
  - : (int32, int32_elt) t = int32 [6] [1, 1, ..., 3, 3]

See also tile.

Combining and splitting

Sourceval concatenate : ?axis:int -> ('a, 'b) t list -> ('a, 'b) t

concatenate ?axis ts joins tensors along an existing axis. All tensors must have the same shape except on the concatenation axis. When axis is omitted, every tensor is flattened first. Always copies.

Raises Invalid_argument if the list is empty or shapes are incompatible.

  # let a = create int32 [| 2; 2 |] [| 1l; 2l; 3l; 4l |] in
    let b = create int32 [| 1; 2 |] [| 5l; 6l |] in
    concatenate ~axis:0 [ a; b ]
  - : (int32, int32_elt) t = int32 [3; 2] [[1, 2],
                                           [3, 4],
                                           [5, 6]]

See also stack, vstack, hstack.

Sourceval stack : ?axis:int -> ('a, 'b) t list -> ('a, 'b) t

stack ?axis ts joins tensors along a new axis. All tensors must have identical shape. axis defaults to 0. Negative values count from the end of the result shape.

Raises Invalid_argument if the list is empty, shapes differ, or axis is out of bounds.

  # let a = create int32 [| 2 |] [| 1l; 2l |] in
    let b = create int32 [| 2 |] [| 3l; 4l |] in
    stack [ a; b ]
  - : (int32, int32_elt) t = int32 [2; 2] [[1, 2],
                                           [3, 4]]
  # let a = create int32 [| 2 |] [| 1l; 2l |] in
    let b = create int32 [| 2 |] [| 3l; 4l |] in
    stack ~axis:1 [ a; b ]
  - : (int32, int32_elt) t = int32 [2; 2] [[1, 3],
                                           [2, 4]]

See also concatenate.

Sourceval vstack : ('a, 'b) t list -> ('a, 'b) t

vstack ts stacks vertically (along axis 0). 1-D tensors are treated as row vectors (shape [1; n]).

Raises Invalid_argument if shapes are incompatible.

  # let a = create int32 [| 3 |] [| 1l; 2l; 3l |] in
    let b = create int32 [| 3 |] [| 4l; 5l; 6l |] in
    vstack [ a; b ]
  - : (int32, int32_elt) t = int32 [2; 3] [[1, 2, 3],
                                           [4, 5, 6]]

See also hstack, dstack, concatenate.

Sourceval hstack : ('a, 'b) t list -> ('a, 'b) t

hstack ts stacks horizontally. 1-D tensors are concatenated directly; higher-D tensors concatenate along axis 1.

Raises Invalid_argument if shapes are incompatible.

  # let a = create int32 [| 2; 1 |] [| 1l; 2l |] in
    let b = create int32 [| 2; 1 |] [| 3l; 4l |] in
    hstack [ a; b ]
  - : (int32, int32_elt) t = int32 [2; 2] [[1, 3],
                                           [2, 4]]

See also vstack, dstack, concatenate.

Sourceval dstack : ('a, 'b) t list -> ('a, 'b) t

dstack ts stacks depth-wise (along axis 2). Tensors are reshaped to at least 3-D before concatenation: 1-D [n][1; n; 1], 2-D [m; n][m; n; 1].

Raises Invalid_argument if the resulting shapes are incompatible.

See also vstack, hstack, concatenate.

Sourceval broadcast_arrays : ('a, 'b) t list -> ('a, 'b) t list

broadcast_arrays ts broadcasts every tensor to their common shape. Returns views (no copies).

Raises Invalid_argument if shapes are incompatible.

See also broadcast_to, broadcasted.

Sourceval array_split : axis:int -> [< `Count of int | `Indices of int list ] -> ('a, 'b) t -> ('a, 'b) t list

array_split ~axis spec t splits t into sub-tensors.

With `Count n, divides as evenly as possible (first sections absorb extra elements). With `Indices [i0; i1; …], splits at the given indices producing [0, i0), [i0, i1), …, [ik, end).

Raises Invalid_argument if axis is out of bounds or spec is invalid.

  # create int32 [| 5 |] [| 1l; 2l; 3l; 4l; 5l |]
    |> array_split ~axis:0 (`Count 3)
  - : (int32, int32_elt) t list = [[1, 2]; [3, 4]; [5]]

See also split.

Sourceval split : axis:int -> int -> ('a, 'b) t -> ('a, 'b) t list

split ~axis n t splits t into n equal parts along axis.

Raises Invalid_argument if the axis size is not divisible by n.

See also array_split.

Type conversion and copying

Sourceval cast : ('c, 'd) dtype -> ('a, 'b) t -> ('c, 'd) t

cast dtype t is a copy of t with elements converted to dtype.

  # create float32 [| 3 |] [| 1.5; 2.7; 3.1 |]
    |> cast int32
  - : (int32, int32_elt) t = [1, 2, 3]

See also contiguous, copy.

Sourceval astype : ('a, 'b) dtype -> ('c, 'd) t -> ('a, 'b) t

astype dtype t is cast.

Sourceval contiguous : ('a, 'b) t -> ('a, 'b) t

contiguous t is t if it is already C-contiguous, or a fresh contiguous copy otherwise.

See also is_c_contiguous, copy.

Sourceval copy : ('a, 'b) t -> ('a, 'b) t

copy t is a deep copy of t. Always allocates new memory; the result is contiguous.

  # let x = create float32 [| 3 |] [| 1.; 2.; 3. |] in
    let y = copy x in
    set_item [ 0 ] 999. y;
    x, y
  - : (float, float32_elt) t * (float, float32_elt) t =
  ([1, 2, 3], [999, 2, 3])

See also contiguous.

Sourceval blit : ('a, 'b) t -> ('a, 'b) t -> unit

blit src dst copies the elements of src into dst in-place. Shapes must match exactly.

Raises Invalid_argument if shapes differ.

Sourceval fill : 'a -> ('a, 'b) t -> ('a, 'b) t

fill v t is a fresh copy of t with every element set to v. Does not mutate t.

Indexing and slicing

Sourceval get : int list -> ('a, 'b) t -> ('a, 'b) t

get indices t is the sub-tensor at indices, indexing from the outermost dimension inward. Returns a scalar tensor when all dimensions are indexed; otherwise a view of the remaining dimensions. Negative indices count from the end.

Raises Invalid_argument if any index is out of bounds.

  # let x =
      create int32 [| 2; 3 |]
        [| 1l; 2l; 3l; 4l; 5l; 6l |]
    in
    get [ 1 ] x
  - : (int32, int32_elt) t = [4, 5, 6]

See also item, slice.

Sourceval set : int list -> ('a, 'b) t -> ('a, 'b) t -> unit

set indices t v writes v at the position given by indices.

Raises Invalid_argument if indices are out of bounds.

Sourceval slice : index list -> ('a, 'b) t -> ('a, 'b) t

slice specs t extracts a sub-tensor using advanced indexing.

Each element of specs addresses one axis from left to right:

  • I i — single index (reduces dimension; negative from end).
  • L [i0; i1; …] — gather listed indices.
  • R (start, stop) — half-open range [start, stop).
  • Rs (start, stop, step) — strided range.
  • A — full axis (default for trailing axes).
  • M mask — boolean mask selecting positions where mask is true.
  • N — insert a new axis of size 1.

Returns a view when possible.

Raises Invalid_argument if specs are out of bounds, if step is zero, or if a mask spec is used (not yet supported).

  # let x =
      create int32 [| 3; 3 |]
        [| 1l; 2l; 3l; 4l; 5l; 6l; 7l; 8l; 9l |]
    in
    slice [ R (0, 2); L [ 0; 2 ] ] x
  - : (int32, int32_elt) t = int32 [2; 2] [[1, 3],
                                           [4, 6]]

See also get, set_slice.

Sourceval set_slice : index list -> ('a, 'b) t -> ('a, 'b) t -> unit

set_slice specs t v writes v into the region of t selected by specs. v is broadcast if needed.

Raises Invalid_argument if N (new-axis) specs are used (not supported for writes).

See also slice.

Sourceval item : int list -> ('a, 'b) t -> 'a

item indices t is the scalar value at indices. Indices must cover all dimensions.

Raises Invalid_argument if the number of indices is wrong or any index is out of bounds.

See also get, set_item.

Sourceval set_item : int list -> 'a -> ('a, 'b) t -> unit

set_item indices v t sets the element at indices to v in-place. Indices must cover all dimensions.

Raises Invalid_argument if the number of indices is wrong or any index is out of bounds.

See also item.

Sourceval take : ?axis:int -> ?mode:[ `raise | `wrap | `clip ] -> (int32, int32_elt) t -> ('a, 'b) t -> ('a, 'b) t

take ?axis ?mode indices t gathers elements from t at indices along axis. When axis is omitted, t is flattened first. mode controls out-of-bounds indices: `raise (default) raises, `wrap uses modular indexing, `clip clamps to bounds.

Raises Invalid_argument if mode is `raise and any index is out of bounds.

  # let x =
      create int32 [| 5 |]
        [| 0l; 1l; 2l; 3l; 4l |]
    in
    take
      (create int32 [| 3 |] [| 1l; 3l; 0l |])
      x
  - : (int32, int32_elt) t = [1, 3, 0]

See also put, take_along_axis.

Sourceval take_along_axis : axis:int -> (int32, int32_elt) t -> ('a, 'b) t -> ('a, 'b) t

take_along_axis ~axis indices t gathers values from t along axis using indices. indices must match t's shape except along axis. Useful for gathering from argmax/argmin results.

Raises Invalid_argument if shapes are incompatible.

  # let x =
      create float32 [| 2; 3 |]
        [| 4.; 1.; 2.; 3.; 5.; 6. |]
    in
    let idx =
      create int32 [| 2; 1 |] [| 1l; 0l |]
    in
    take_along_axis ~axis:1 idx x
  - : (float, float32_elt) t = float32 [2; 1] [[1],
                                               [3]]

See also take, put_along_axis.

Sourceval put : ?axis:int -> indices:(int32, int32_elt) t -> values:('a, 'b) t -> ?mode:[ `raise | `wrap | `clip ] -> ('a, 'b) t -> unit

put ?axis ~indices ~values ?mode t writes values into t at positions given by indices. When axis is omitted, t is flattened first. mode defaults to `raise. Modifies t in-place.

Raises Invalid_argument if mode is `raise and any index is out of bounds.

See also take, put_along_axis, index_put.

Sourceval index_put : indices:(int32, int32_elt) t array -> values:('a, 'b) t -> ?mode:[ `raise | `wrap | `clip ] -> ('a, 'b) t -> unit

index_put ~indices ~values ?mode t writes values into t at the coordinates given by indices.

indices contains one index tensor per axis of t; they are broadcast to a common shape that determines the number of updates. values is broadcast to the same shape. Duplicate coordinates overwrite. mode defaults to `raise.

Raises Invalid_argument if the number of index tensors does not match ndim t.

  # let t = zeros float32 [| 3; 3 |] in
    let rows =
      create int32 [| 3 |] [| 0l; 2l; 1l |]
    in
    let cols =
      create int32 [| 3 |] [| 1l; 0l; 2l |]
    in
    index_put ~indices:[| rows; cols |]
      ~values:(create float32 [| 3 |]
                 [| 10.; 20.; 30. |])
      t;
    t
  - : (float, float32_elt) t = float32 [3; 3]
  [[0, 10, 0],
   [0, 0, 30],
   [20, 0, 0]]

See also put.

Sourceval put_along_axis : axis:int -> indices:(int32, int32_elt) t -> values:('a, 'b) t -> ('a, 'b) t -> unit

put_along_axis ~axis ~indices ~values t writes values into t at positions selected by indices along axis. Modifies t in-place.

Raises Invalid_argument if shapes are incompatible.

See also take_along_axis, put.

Sourceval compress : ?axis:int -> condition:(bool, bool_elt) t -> ('a, 'b) t -> ('a, 'b) t

compress ?axis ~condition t selects elements where condition is true along axis. condition must be 1-D. When axis is omitted, t is flattened first.

Raises Invalid_argument if the condition length is incompatible.

  # let x =
      create int32 [| 5 |]
        [| 1l; 2l; 3l; 4l; 5l |]
    in
    compress
      ~condition:(create bool [| 5 |]
        [| true; false; true; false; true |])
      x
  - : (int32, int32_elt) t = [1, 3, 5]

See also extract, nonzero.

Sourceval extract : condition:(bool, bool_elt) t -> ('a, 'b) t -> ('a, 'b) t

extract ~condition t is the 1-D tensor of elements of t where condition is true. Both are flattened before comparison.

Raises Invalid_argument if sizes differ.

See also compress, nonzero.

Sourceval nonzero : ('a, 'b) t -> (int32, int32_elt) t array

nonzero t is an array of 1-D index tensors, one per dimension, giving the coordinates of non-zero elements.

  # let x =
      create int32 [| 3; 3 |]
        [| 0l; 1l; 0l;
           2l; 0l; 3l;
           0l; 0l; 4l |]
    in
    let idx = nonzero x in
    idx.(0), idx.(1)
  - : (int32, int32_elt) t * (int32, int32_elt) t =
  ([0, 1, 1, 2], [1, 0, 2, 2])

See also argwhere.

Sourceval argwhere : ('a, 'b) t -> (int32, int32_elt) t

argwhere t is a 2-D tensor of shape [k; ndim t] whose rows are the coordinates of the k non-zero elements.

See also nonzero.

Arithmetic

Element-wise arithmetic with broadcasting. Each operation op has variants:

  • op_s t s — tensor-scalar.
  • rop_s s t — scalar-tensor (reversed operands).
Sourceval add : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

add ?out a b is the element-wise sum of a and b. out defaults to a fresh allocation.

Sourceval add_s : ?out:('a, 'b) t -> ('a, 'b) t -> 'a -> ('a, 'b) t

add_s ?out t s adds scalar s to each element of t. out defaults to a fresh allocation.

Sourceval radd_s : ?out:('a, 'b) t -> 'a -> ('a, 'b) t -> ('a, 'b) t

radd_s ?out s t is add_s ?out t s.

Sourceval sub : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

sub ?out a b is the element-wise difference a - b. out defaults to a fresh allocation.

Sourceval sub_s : ?out:('a, 'b) t -> ('a, 'b) t -> 'a -> ('a, 'b) t

sub_s ?out t s subtracts scalar s from each element. out defaults to a fresh allocation.

Sourceval rsub_s : ?out:('a, 'b) t -> 'a -> ('a, 'b) t -> ('a, 'b) t

rsub_s ?out s t is s - t element-wise. out defaults to a fresh allocation.

Sourceval mul : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

mul ?out a b is the element-wise product of a and b. out defaults to a fresh allocation.

Sourceval mul_s : ?out:('a, 'b) t -> ('a, 'b) t -> 'a -> ('a, 'b) t

mul_s ?out t s multiplies each element by scalar s. out defaults to a fresh allocation.

Sourceval rmul_s : ?out:('a, 'b) t -> 'a -> ('a, 'b) t -> ('a, 'b) t

rmul_s ?out s t is mul_s ?out t s.

Sourceval div : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

div ?out a b is the element-wise quotient a / b. out defaults to a fresh allocation.

Float dtypes use true division. Integer dtypes truncate toward zero.

  # let x = create int32 [| 2 |] [| -7l; 8l |] in
    let y = create int32 [| 2 |] [| 2l; 2l |] in
    div x y
  - : (int32, int32_elt) t = [-3, 4]
Sourceval div_s : ?out:('a, 'b) t -> ('a, 'b) t -> 'a -> ('a, 'b) t

div_s ?out t s divides each element by scalar s. out defaults to a fresh allocation.

Sourceval rdiv_s : ?out:('a, 'b) t -> 'a -> ('a, 'b) t -> ('a, 'b) t

rdiv_s ?out s t is s / t element-wise. out defaults to a fresh allocation.

Sourceval pow : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

pow ?out base exp is base raised to exp element-wise. out defaults to a fresh allocation.

Sourceval pow_s : ?out:('a, 'b) t -> ('a, 'b) t -> 'a -> ('a, 'b) t

pow_s ?out t s raises each element to scalar power s. out defaults to a fresh allocation.

Sourceval rpow_s : ?out:('a, 'b) t -> 'a -> ('a, 'b) t -> ('a, 'b) t

rpow_s ?out s t is s{^t} element-wise. out defaults to a fresh allocation.

Sourceval mod_ : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

mod_ ?out a b is the element-wise remainder of a / b. out defaults to a fresh allocation.

Sourceval mod_s : ?out:('a, 'b) t -> ('a, 'b) t -> 'a -> ('a, 'b) t

mod_s ?out t s is the remainder of each element divided by scalar s. out defaults to a fresh allocation.

Sourceval rmod_s : ?out:('a, 'b) t -> 'a -> ('a, 'b) t -> ('a, 'b) t

rmod_s ?out s t is s mod t element-wise. out defaults to a fresh allocation.

Sourceval neg : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

neg ?out t is the element-wise negation of t. out defaults to a fresh allocation.

Sourceval conjugate : ('a, 'b) t -> ('a, 'b) t

conjugate t is the complex conjugate of t. For complex dtypes, negates the imaginary part. For real dtypes, returns t unchanged.

Mathematical functions

Basic

Sourceval abs : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

abs ?out t is the element-wise absolute value. out defaults to a fresh allocation.

Sourceval sign : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

sign ?out t is -1, 0, or 1 according to the sign of each element. For unsigned types, returns 1 for non-zero, 0 for zero. out defaults to a fresh allocation.

  # create float32 [| 3 |] [| -2.; 0.; 3.5 |]
    |> sign
  - : (float, float32_elt) t = [-1, 0, 1]
Sourceval square : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

square ?out t is the element-wise square. out defaults to a fresh allocation.

Sourceval sqrt : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

sqrt ?out t is the element-wise square root. out defaults to a fresh allocation.

Sourceval rsqrt : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

rsqrt ?out t is the element-wise reciprocal square root (1 / sqrt t). out defaults to a fresh allocation.

Sourceval recip : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

recip ?out t is the element-wise reciprocal (1 / t). out defaults to a fresh allocation.

Exponential and logarithmic

Sourceval log : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

log ?out t is the element-wise natural logarithm. out defaults to a fresh allocation.

Sourceval log2 : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

log2 ?out t is the element-wise base-2 logarithm. out defaults to a fresh allocation.

Sourceval exp : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

exp ?out t is the element-wise exponential. out defaults to a fresh allocation.

Sourceval exp2 : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

exp2 ?out t is 2{^t} element-wise. out defaults to a fresh allocation.

Trigonometric

Sourceval sin : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

sin ?out t is the element-wise sine. out defaults to a fresh allocation.

Sourceval cos : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

cos ?out t is the element-wise cosine. out defaults to a fresh allocation.

Sourceval tan : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

tan ?out t is the element-wise tangent. out defaults to a fresh allocation.

Sourceval asin : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

asin ?out t is the element-wise arcsine. out defaults to a fresh allocation.

Sourceval acos : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

acos ?out t is the element-wise arccosine. out defaults to a fresh allocation.

Sourceval atan : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

atan ?out t is the element-wise arctangent. out defaults to a fresh allocation.

Sourceval atan2 : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

atan2 ?out y x is the element-wise two-argument arctangent, returning angles in [, π]. out defaults to a fresh allocation.

Hyperbolic

Sourceval sinh : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

sinh ?out t is the element-wise hyperbolic sine. out defaults to a fresh allocation.

Sourceval cosh : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

cosh ?out t is the element-wise hyperbolic cosine. out defaults to a fresh allocation.

Sourceval tanh : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

tanh ?out t is the element-wise hyperbolic tangent. out defaults to a fresh allocation.

Sourceval asinh : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

asinh ?out t is the element-wise inverse hyperbolic sine. out defaults to a fresh allocation.

Sourceval acosh : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

acosh ?out t is the element-wise inverse hyperbolic cosine. out defaults to a fresh allocation.

Sourceval atanh : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

atanh ?out t is the element-wise inverse hyperbolic tangent. out defaults to a fresh allocation.

Rounding

Sourceval trunc : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

trunc ?out t rounds each element toward zero. out defaults to a fresh allocation.

Sourceval ceil : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

ceil ?out t rounds each element toward positive infinity. out defaults to a fresh allocation.

Sourceval floor : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

floor ?out t rounds each element toward negative infinity. out defaults to a fresh allocation.

Sourceval round : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

round ?out t rounds each element to the nearest integer. Ties round away from zero (not banker's rounding). out defaults to a fresh allocation.

  # create float32 [| 4 |] [| 2.5; 3.5; -2.5; -3.5 |]
    |> round
  - : (float, float32_elt) t = [3, 4, -3, -4]

Other

Sourceval hypot : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

hypot ?out x y is sqrt(x² + y²) computed without intermediate overflow. out defaults to a fresh allocation.

  # hypot (scalar float32 3.) (scalar float32 4.)
    |> item []
  - : float = 5.
Sourceval lerp : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

lerp ?out a b w is the linear interpolation a + w * (b - a). w is typically in [0, 1]. out defaults to a fresh allocation.

  # let a = create float32 [| 2 |] [| 1.; 2. |] in
    let b = create float32 [| 2 |] [| 5.; 8. |] in
    lerp a b (scalar float32 0.25)
  - : (float, float32_elt) t = [2, 3.5]
Sourceval lerp_scalar_weight : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> 'a -> ('a, 'b) t

lerp_scalar_weight ?out a b w is like lerp with a scalar weight. out defaults to a fresh allocation.

Sourceval isinf : ?out:(bool, bool_elt) t -> ('a, 'b) t -> (bool, bool_elt) t

isinf ?out t is true where t is positive or negative infinity, false elsewhere. Non-float dtypes always return all false. out defaults to a fresh allocation.

  # create float32 [| 4 |]
      [| 1.; Float.infinity;
         Float.neg_infinity; Float.nan |]
    |> isinf
  - : (bool, bool_elt) t = [false, true, true, false]

See also isnan, isfinite.

Sourceval isnan : ?out:(bool, bool_elt) t -> ('a, 'b) t -> (bool, bool_elt) t

isnan ?out t is true where t is NaN, false elsewhere. Non-float dtypes always return all false. out defaults to a fresh allocation.

See also isinf, isfinite.

Sourceval isfinite : ?out:(bool, bool_elt) t -> ('a, 'b) t -> (bool, bool_elt) t

isfinite ?out t is true where t is neither infinite nor NaN. Non-float dtypes always return all true. out defaults to a fresh allocation.

See also isinf, isnan.

Comparison and logic

Sourceval cmplt : ?out:(bool, bool_elt) t -> ('a, 'b) t -> ('a, 'b) t -> (bool, bool_elt) t

cmplt ?out a b is true where a < b, false elsewhere. out defaults to a fresh allocation.

Sourceval less : ?out:(bool, bool_elt) t -> ('a, 'b) t -> ('a, 'b) t -> (bool, bool_elt) t

less a b is cmplt.

Sourceval less_s : ?out:(bool, bool_elt) t -> ('a, 'b) t -> 'a -> (bool, bool_elt) t

less_s ?out t s is true where t < s. out defaults to a fresh allocation.

Sourceval cmpne : ?out:(bool, bool_elt) t -> ('a, 'b) t -> ('a, 'b) t -> (bool, bool_elt) t

cmpne ?out a b is true where a ≠ b, false elsewhere. out defaults to a fresh allocation.

Sourceval not_equal : ?out:(bool, bool_elt) t -> ('a, 'b) t -> ('a, 'b) t -> (bool, bool_elt) t

not_equal a b is cmpne.

Sourceval not_equal_s : ?out:(bool, bool_elt) t -> ('a, 'b) t -> 'a -> (bool, bool_elt) t

not_equal_s ?out t s is true where t ≠ s. out defaults to a fresh allocation.

Sourceval cmpeq : ?out:(bool, bool_elt) t -> ('a, 'b) t -> ('a, 'b) t -> (bool, bool_elt) t

cmpeq ?out a b is true where a = b, false elsewhere. out defaults to a fresh allocation.

Sourceval equal : ?out:(bool, bool_elt) t -> ('a, 'b) t -> ('a, 'b) t -> (bool, bool_elt) t

equal a b is cmpeq.

Sourceval equal_s : ?out:(bool, bool_elt) t -> ('a, 'b) t -> 'a -> (bool, bool_elt) t

equal_s ?out t s is true where t = s. out defaults to a fresh allocation.

Sourceval cmpgt : ?out:(bool, bool_elt) t -> ('a, 'b) t -> ('a, 'b) t -> (bool, bool_elt) t

cmpgt ?out a b is true where a > b, false elsewhere. out defaults to a fresh allocation.

Sourceval greater : ?out:(bool, bool_elt) t -> ('a, 'b) t -> ('a, 'b) t -> (bool, bool_elt) t

greater a b is cmpgt.

Sourceval greater_s : ?out:(bool, bool_elt) t -> ('a, 'b) t -> 'a -> (bool, bool_elt) t

greater_s ?out t s is true where t > s. out defaults to a fresh allocation.

Sourceval cmple : ?out:(bool, bool_elt) t -> ('a, 'b) t -> ('a, 'b) t -> (bool, bool_elt) t

cmple ?out a b is true where a ≤ b, false elsewhere. out defaults to a fresh allocation.

Sourceval less_equal : ?out:(bool, bool_elt) t -> ('a, 'b) t -> ('a, 'b) t -> (bool, bool_elt) t

less_equal a b is cmple.

Sourceval less_equal_s : ?out:(bool, bool_elt) t -> ('a, 'b) t -> 'a -> (bool, bool_elt) t

less_equal_s ?out t s is true where t ≤ s. out defaults to a fresh allocation.

Sourceval cmpge : ?out:(bool, bool_elt) t -> ('a, 'b) t -> ('a, 'b) t -> (bool, bool_elt) t

cmpge ?out a b is true where a ≥ b, false elsewhere. out defaults to a fresh allocation.

Sourceval greater_equal : ?out:(bool, bool_elt) t -> ('a, 'b) t -> ('a, 'b) t -> (bool, bool_elt) t

greater_equal a b is cmpge.

Sourceval greater_equal_s : ?out:(bool, bool_elt) t -> ('a, 'b) t -> 'a -> (bool, bool_elt) t

greater_equal_s ?out t s is true where t ≥ s. out defaults to a fresh allocation.

Sourceval array_equal : ('a, 'b) t -> ('a, 'b) t -> (bool, bool_elt) t

array_equal a b is a scalar true iff all elements of a and b are equal. Returns false if shapes differ.

  # let a = create int32 [| 3 |] [| 1l; 2l; 3l |] in
    let b = create int32 [| 3 |] [| 1l; 2l; 3l |] in
    array_equal a b |> item []
  - : bool = true
Sourceval maximum : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

maximum ?out a b is the element-wise maximum of a and b. out defaults to a fresh allocation.

Sourceval maximum_s : ?out:('a, 'b) t -> ('a, 'b) t -> 'a -> ('a, 'b) t

maximum_s ?out t s is the element-wise maximum of t and scalar s. out defaults to a fresh allocation.

Sourceval rmaximum_s : ?out:('a, 'b) t -> 'a -> ('a, 'b) t -> ('a, 'b) t

rmaximum_s ?out s t is maximum_s ?out t s.

Sourceval minimum : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

minimum ?out a b is the element-wise minimum of a and b. out defaults to a fresh allocation.

Sourceval minimum_s : ?out:('a, 'b) t -> ('a, 'b) t -> 'a -> ('a, 'b) t

minimum_s ?out t s is the element-wise minimum of t and scalar s. out defaults to a fresh allocation.

Sourceval rminimum_s : ?out:('a, 'b) t -> 'a -> ('a, 'b) t -> ('a, 'b) t

rminimum_s ?out s t is minimum_s ?out t s.

Sourceval logical_and : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

logical_and ?out a b is the element-wise logical AND. Non-zero is true. out defaults to a fresh allocation.

Sourceval logical_or : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

logical_or ?out a b is the element-wise logical OR. out defaults to a fresh allocation.

Sourceval logical_xor : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

logical_xor ?out a b is the element-wise logical XOR. out defaults to a fresh allocation.

Sourceval logical_not : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

logical_not ?out t is the element-wise logical NOT: non-zero becomes 0, zero becomes 1. out defaults to a fresh allocation.

Sourceval where : ?out:('a, 'b) t -> (bool, bool_elt) t -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

where ?out cond if_true if_false selects elements from if_true where cond is true and from if_false elsewhere. All three inputs broadcast to a common shape. out defaults to a fresh allocation.

  # let x =
      create float32 [| 4 |] [| -1.; 2.; -3.; 4. |]
    in
    where
      (cmpgt x (scalar float32 0.))
      x (scalar float32 0.)
  - : (float, float32_elt) t = [0, 2, 0, 4]
Sourceval clamp : ?out:('a, 'b) t -> ?min:'a -> ?max:'a -> ('a, 'b) t -> ('a, 'b) t

clamp ?out ?min ?max t clamps elements to [min, max]. Either bound may be omitted. out defaults to a fresh allocation.

See also clip.

Sourceval clip : ?out:('a, 'b) t -> ?min:'a -> ?max:'a -> ('a, 'b) t -> ('a, 'b) t

clip ?out ?min ?max t is clamp.

Bitwise operations

Sourceval bitwise_xor : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

bitwise_xor ?out a b is the element-wise bitwise XOR. out defaults to a fresh allocation.

Sourceval bitwise_or : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

bitwise_or ?out a b is the element-wise bitwise OR. out defaults to a fresh allocation.

Sourceval bitwise_and : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

bitwise_and ?out a b is the element-wise bitwise AND. out defaults to a fresh allocation.

Sourceval bitwise_not : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

bitwise_not ?out t is the element-wise bitwise NOT. out defaults to a fresh allocation.

Sourceval invert : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

invert ?out t is bitwise_not.

Sourceval lshift : ?out:('a, 'b) t -> ('a, 'b) t -> int -> ('a, 'b) t

lshift ?out t n left-shifts each element by n bits. out defaults to a fresh allocation.

Raises Invalid_argument if n is negative or the dtype is not an integer type.

  # create int32 [| 3 |] [| 1l; 2l; 3l |]
    |> Fun.flip lshift 2
  - : (int32, int32_elt) t = [4, 8, 12]

See also rshift.

Sourceval rshift : ?out:('a, 'b) t -> ('a, 'b) t -> int -> ('a, 'b) t

rshift ?out t n right-shifts each element by n bits. out defaults to a fresh allocation.

Raises Invalid_argument if n is negative or the dtype is not an integer type.

See also lshift.

Infix operators

Sourcemodule Infix : sig ... end

Reductions

Sourceval sum : ?out:('a, 'b) t -> ?axes:int list -> ?keepdims:bool -> ('a, 'b) t -> ('a, 'b) t

sum ?out ?axes ?keepdims t sums elements along axes. When axes is omitted, reduces all axes (returns a scalar). When keepdims is true, reduced axes are kept with size 1. keepdims defaults to false. Negative axes count from the end. out defaults to a fresh allocation.

  # create float32 [| 2; 2 |] [| 1.; 2.; 3.; 4. |]
    |> sum |> item []
  - : float = 10.
  # create float32 [| 2; 2 |] [| 1.; 2.; 3.; 4. |]
    |> sum ~axes:[ 0 ]
  - : (float, float32_elt) t = [4, 6]
  # create float32 [| 1; 2 |] [| 1.; 2. |]
    |> sum ~axes:[ 1 ] ~keepdims:true
  - : (float, float32_elt) t = float32 [1; 1] [[3]]
Sourceval max : ?out:('a, 'b) t -> ?axes:int list -> ?keepdims:bool -> ('a, 'b) t -> ('a, 'b) t

max ?out ?axes ?keepdims t is the maximum along axes. NaN propagates. keepdims defaults to false. out defaults to a fresh allocation.

  # create float32 [| 2; 3 |]
      [| 1.; 2.; 3.; 4.; 5.; 6. |]
    |> max |> item []
  - : float = 6.
Sourceval min : ?out:('a, 'b) t -> ?axes:int list -> ?keepdims:bool -> ('a, 'b) t -> ('a, 'b) t

min ?out ?axes ?keepdims t is the minimum along axes. NaN propagates. keepdims defaults to false. out defaults to a fresh allocation.

Sourceval prod : ?out:('a, 'b) t -> ?axes:int list -> ?keepdims:bool -> ('a, 'b) t -> ('a, 'b) t

prod ?out ?axes ?keepdims t is the product along axes. keepdims defaults to false. out defaults to a fresh allocation.

  # create int32 [| 3 |] [| 2l; 3l; 4l |]
    |> prod |> item []
  - : int32 = 24l
Sourceval cumsum : ?axis:int -> ('a, 'b) t -> ('a, 'b) t

cumsum ?axis t is the inclusive cumulative sum along axis. When axis is omitted, operates on the flattened tensor.

See also cumprod.

Sourceval cumprod : ?axis:int -> ('a, 'b) t -> ('a, 'b) t

cumprod ?axis t is the inclusive cumulative product along axis. When axis is omitted, operates on the flattened tensor.

See also cumsum.

Sourceval cummax : ?axis:int -> ('a, 'b) t -> ('a, 'b) t

cummax ?axis t is the inclusive cumulative maximum along axis. NaN propagates for floating-point dtypes. When axis is omitted, operates on the flattened tensor.

See also cummin.

Sourceval cummin : ?axis:int -> ('a, 'b) t -> ('a, 'b) t

cummin ?axis t is the inclusive cumulative minimum along axis. NaN propagates for floating-point dtypes. When axis is omitted, operates on the flattened tensor.

See also cummax.

Sourceval mean : ?out:('a, 'b) t -> ?axes:int list -> ?keepdims:bool -> ('a, 'b) t -> ('a, 'b) t

mean ?out ?axes ?keepdims t is the arithmetic mean along axes. NaN propagates. keepdims defaults to false. out defaults to a fresh allocation.

  # create float32 [| 4 |] [| 1.; 2.; 3.; 4. |]
    |> mean |> item []
  - : float = 2.5
Sourceval var : ?out:('a, 'b) t -> ?axes:int list -> ?keepdims:bool -> ?ddof:int -> ('a, 'b) t -> ('a, 'b) t

var ?out ?axes ?keepdims ?ddof t is the variance along axes. ddof (delta degrees of freedom) defaults to 0 (population variance); use 1 for sample variance. Computed as E[(X - E[X])²] / (N - ddof). keepdims defaults to false. out defaults to a fresh allocation.

Raises Invalid_argument if ddof >= N.

  # create float32 [| 5 |] [| 1.; 2.; 3.; 4.; 5. |]
    |> var |> item []
  - : float = 2.
  # create float32 [| 5 |] [| 1.; 2.; 3.; 4.; 5. |]
    |> var ~ddof:1 |> item []
  - : float = 2.5

See also std.

Sourceval std : ?out:('a, 'b) t -> ?axes:int list -> ?keepdims:bool -> ?ddof:int -> ('a, 'b) t -> ('a, 'b) t

std ?out ?axes ?keepdims ?ddof t is the standard deviation: sqrt({!var} ~ddof t). ddof defaults to 0. keepdims defaults to false. out defaults to a fresh allocation.

See also var.

Sourceval all : ?out:(bool, bool_elt) t -> ?axes:int list -> ?keepdims:bool -> ('a, 'b) t -> (bool, bool_elt) t

all ?out ?axes ?keepdims t is true iff every element along axes is non-zero. keepdims defaults to false. out defaults to a fresh allocation.

  # create int32 [| 3 |] [| 1l; 2l; 3l |]
    |> all |> item []
  - : bool = true
  # create int32 [| 3 |] [| 1l; 0l; 3l |]
    |> all |> item []
  - : bool = false

See also any.

Sourceval any : ?out:(bool, bool_elt) t -> ?axes:int list -> ?keepdims:bool -> ('a, 'b) t -> (bool, bool_elt) t

any ?out ?axes ?keepdims t is true iff at least one element along axes is non-zero. keepdims defaults to false. out defaults to a fresh allocation.

See also all.

Sourceval argmax : ?axis:int -> ?keepdims:bool -> ('a, 'b) t -> (int32, int32_elt) t

argmax ?axis ?keepdims t is the index of the maximum along axis. Returns the first occurrence for ties. When axis is omitted, operates on the flattened tensor. keepdims defaults to false.

Raises Invalid_argument if axis is out of bounds.

  # create int32 [| 5 |] [| 3l; 1l; 4l; 1l; 5l |]
    |> argmax |> item []
  - : int32 = 4l

See also argmin.

Sourceval argmin : ?axis:int -> ?keepdims:bool -> ('a, 'b) t -> (int32, int32_elt) t

argmin ?axis ?keepdims t is the index of the minimum along axis. Returns the first occurrence for ties. When axis is omitted, operates on the flattened tensor. keepdims defaults to false.

Raises Invalid_argument if axis is out of bounds.

See also argmax.

Sorting and searching

Sourceval sort : ?descending:bool -> ?axis:int -> ('a, 'b) t -> ('a, 'b) t * (int32, int32_elt) t

sort ?descending ?axis t sorts elements along axis and returns (sorted, indices) where indices maps sorted positions back to originals. descending defaults to false. axis defaults to -1 (last).

The sort is stable (equal elements preserve their relative order). NaN sorts to the end in ascending order and to the beginning in descending order.

Raises Invalid_argument if axis is out of bounds.

  # create int32 [| 5 |] [| 3l; 1l; 4l; 1l; 5l |]
    |> sort
  - : (int32, int32_elt) t * (int32, int32_elt) t =
  (int32 [5] [1, 1, ..., 4, 5], int32 [5] [1, 3, ..., 2, 4])

See also argsort.

Sourceval argsort : ?descending:bool -> ?axis:int -> ('a, 'b) t -> (int32, int32_elt) t

argsort ?descending ?axis t is snd (sort ?descending ?axis t).

See also sort.

Linear algebra

Products

Sourceval dot : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

dot ?out a b is the generalised dot product. out defaults to a fresh allocation.

Contracts the last axis of a with:

  • the only axis of b when b is 1-D,
  • the second-to-last axis of b otherwise.

Dimension rules:

  • 1-D × 1-D → scalar (inner product).
  • 2-D × 2-D → matrix multiplication.
  • N-D × M-D → contraction; output axes are the non-contracted axes of a followed by those of b.

Note. Unlike matmul, dot does not broadcast batch dimensions—it concatenates them.

Raises Invalid_argument if contraction axes differ in size or either input is 0-D.

  # let a = create float32 [| 2 |] [| 1.; 2. |] in
    let b = create float32 [| 2 |] [| 3.; 4. |] in
    dot a b |> item []
  - : float = 11.
  # dot (ones float32 [| 3; 4; 5 |])
        (ones float32 [| 5; 6 |]) |> shape
  - : int array = [|3; 4; 6|]

See also matmul, vdot, vecdot.

Sourceval matmul : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

matmul ?out a b is the matrix product of a and b with batch broadcasting. out defaults to a fresh allocation; ignored when either input is 1-D.

Dimension rules:

  • 1-D × 1-D → scalar (inner product).
  • 1-D × N-D → a is treated as a row vector.
  • N-D × 1-D → b is treated as a column vector.
  • N-D × M-D → matrix multiply on last two axes; leading axes are broadcast.

Raises Invalid_argument if inputs are 0-D or inner dimensions mismatch.

  # let a =
      create float32 [| 2; 2 |] [| 1.; 2.; 3.; 4. |]
    in
    let b = create float32 [| 2 |] [| 5.; 6. |] in
    matmul a b
  - : (float, float32_elt) t = [17, 39]
  # matmul (ones float32 [| 1; 3; 4 |])
           (ones float32 [| 5; 4; 2 |]) |> shape
  - : int array = [|5; 3; 2|]

See also dot, multi_dot.

Sourceval diagonal : ?offset:int -> ?axis1:int -> ?axis2:int -> ('a, 'b) t -> ('a, 'b) t

diagonal ?offset ?axis1 ?axis2 t extracts diagonals from 2-D planes defined by axis1 and axis2. offset defaults to 0. axis1 and axis2 default to the last two axes.

Raises Invalid_argument if axis1 = axis2 or either is out of bounds.

See also diag, trace.

Sourceval matrix_transpose : ('a, 'b) t -> ('a, 'b) t

matrix_transpose t swaps the last two axes: […; m; n][…; n; m]. For 1-D tensors, returns t unchanged.

See also transpose.

Sourceval vdot : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

vdot a b is the dot product of two vectors. Both inputs are flattened; for complex dtypes, a is conjugated first. Always returns a scalar.

Raises Invalid_argument if the inputs have different numbers of elements.

See also dot, vecdot.

Sourceval vecdot : ?axis:int -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

vecdot ?axis a b is the dot product of a and b along axis with broadcasting. axis defaults to -1.

Raises Invalid_argument if the specified axis dimensions differ.

See also vdot, dot.

Sourceval inner : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

inner a b is the inner product over the last axes of a and b.

Raises Invalid_argument if the last dimensions differ.

See also dot, outer.

Sourceval outer : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

outer ?out a b is the outer product. Inputs are flattened to 1-D; the result has shape [numel a; numel b]. out defaults to a fresh allocation.

See also inner.

Sourceval tensordot : ?axes:(int list * int list) -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

tensordot ?axes a b contracts a and b along the specified axis pairs. axes defaults to contracting the last axis of a with the first axis of b.

Raises Invalid_argument if the contracted axes have different sizes.

Sourceval einsum : string -> ('a, 'b) t array -> ('a, 'b) t

einsum subscripts operands evaluates Einstein summation.

  # let a =
      create float32 [| 2; 3 |]
        [| 1.; 2.; 3.; 4.; 5.; 6. |]
    in
    let b =
      create float32 [| 3; 2 |]
        [| 1.; 2.; 3.; 4.; 5.; 6. |]
    in
    einsum "ij,jk->ik" [| a; b |] |> shape
  - : int array = [|2; 2|]

See also matmul, tensordot.

Sourceval kron : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

kron a b is the Kronecker product. The result has shape [a.shape.(i) * b.shape.(i)] for each i.

Sourceval multi_dot : ('a, 'b) t array -> ('a, 'b) t

multi_dot ts is the chained matrix product of ts, automatically choosing the association order that minimises computation.

Raises Invalid_argument if the array is empty, shapes are incompatible, or dtypes are not floating-point or complex.

See also matmul.

Sourceval matrix_power : ('a, 'b) t -> int -> ('a, 'b) t

matrix_power t n raises square matrix t to integer power n. n = 0 returns the identity; n < 0 uses the inverse.

Raises Invalid_argument if t is not square, the dtype is not floating-point or complex, or n < 0 and t is singular.

Sourceval cross : ?out:('a, 'b) t -> ?axis:int -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

cross ?out ?axis a b is the cross product of 3-element vectors along axis. axis defaults to -1. out defaults to a fresh allocation.

Raises Invalid_argument if the axis dimension is not 3.

Decompositions

Sourceval cholesky : ?upper:bool -> ('a, 'b) t -> ('a, 'b) t

cholesky ?upper a is the Cholesky factor of positive- definite matrix a. When upper is true, returns the upper-triangular factor U such that a = Uᵀ U; otherwise (default) returns the lower-triangular factor L such that a = L Lᵀ.

Raises Invalid_argument if a is not positive-definite or the dtype is not floating-point or complex.

See also solve.

Sourceval qr : ?mode:[ `Complete | `Reduced ] -> ('a, 'b) t -> ('a, 'b) t * ('a, 'b) t

qr ?mode a is (Q, R) where a = Q R, Q is orthogonal, and R is upper-triangular. mode defaults to `Reduced.

Raises Invalid_argument if the dtype is not floating-point or complex.

See also svd.

Sourceval svd : ?full_matrices:bool -> ('a, 'b) t -> ('a, 'b) t * (float, float64_elt) t * ('a, 'b) t

svd ?full_matrices a is (U, S, Vh) where a = U diag(S) Vh. S contains the singular values in descending order. full_matrices defaults to false (economy decomposition).

Raises Invalid_argument if the dtype is not floating-point or complex.

See also svdvals, qr.

Sourceval svdvals : ('a, 'b) t -> (float, float64_elt) t

svdvals a is the singular values of a in descending order. More efficient than svd when only the values are needed.

Raises Invalid_argument if the dtype is not floating-point or complex.

Eigenvalues and eigenvectors

Sourceval eig : ('a, 'b) t -> (Complex.t, complex64_elt) t * (Complex.t, complex64_elt) t

eig a is (eigenvalues, eigenvectors) of general square matrix a. Results are complex since real matrices may have complex eigenvalues.

Raises Invalid_argument if a is not square or the dtype is not floating-point or complex.

See also eigh, eigvals.

Sourceval eigh : ?uplo:[ `U | `L ] -> ('a, 'b) t -> (float, float64_elt) t * ('a, 'b) t

eigh ?uplo a is (eigenvalues, eigenvectors) of symmetric / Hermitian matrix a in ascending eigenvalue order. uplo defaults to `L. More efficient than eig for symmetric matrices.

Raises Invalid_argument if a is not square or the dtype is not floating-point or complex.

See also eig, eigvalsh.

Sourceval eigvals : ('a, 'b) t -> (Complex.t, complex64_elt) t

eigvals a is the eigenvalues of general square matrix a. More efficient than eig when eigenvectors are not needed.

Raises Invalid_argument if a is not square or the dtype is not floating-point or complex.

See also eig, eigvalsh.

Sourceval eigvalsh : ?uplo:[ `U | `L ] -> ('a, 'b) t -> (float, float64_elt) t

eigvalsh ?uplo a is the eigenvalues of symmetric / Hermitian matrix a in ascending order. uplo defaults to `L.

Raises Invalid_argument if a is not square or the dtype is not floating-point or complex.

See also eigh, eigvals.

Norms and invariants

Sourceval norm : ?ord: [ `Fro | `Nuc | `One | `Two | `Inf | `NegOne | `NegTwo | `NegInf | `P of float ] -> ?axes:int list -> ?keepdims:bool -> ('a, 'b) t -> ('a, 'b) t

norm ?ord ?axes ?keepdims t is the matrix or vector norm. ord defaults to Frobenius for matrices, 2-norm for vectors. keepdims defaults to false.

  • `Fro — Frobenius norm.
  • `Nuc — nuclear norm (sum of singular values).
  • `One — max absolute column sum (matrix) or 1-norm (vector).
  • `Two — largest singular value (matrix) or 2-norm (vector).
  • `Inf — max absolute row sum (matrix) or ∞-norm (vector).
  • `P p — p-norm (vectors only).
  • `NegOne, `NegTwo, `NegInf — corresponding minimum norms.

Raises Invalid_argument if ord requires a floating-point or complex dtype.

Sourceval cond : ?p:[ `One | `Two | `Inf | `NegOne | `NegTwo | `NegInf | `Fro ] -> ('a, 'b) t -> ('a, 'b) t

cond ?p a is the condition number of a in the p-norm. p defaults to `Two.

Raises Invalid_argument if the dtype is not floating-point or complex.

Sourceval det : ('a, 'b) t -> ('a, 'b) t

det a is the determinant of square matrix a.

Raises Invalid_argument if a is not square or the dtype is not floating-point or complex.

Sourceval slogdet : ('a, 'b) t -> (float, float32_elt) t * (float, float32_elt) t

slogdet a is (sign, log_abs_det) where det a = sign * exp(log_abs_det). More numerically stable than det for matrices with very large or small determinants.

Raises Invalid_argument if a is not square or the dtype is not floating-point or complex.

Sourceval matrix_rank : ?tol:float -> ?rtol:float -> ?hermitian:bool -> ('a, 'b) t -> int

matrix_rank ?tol ?rtol ?hermitian a is the rank of a, counting singular values above the tolerance. rtol defaults to max(M, N) * ε * σ_max. When hermitian is true (default false), uses a more efficient eigenvalue-based algorithm.

Raises Invalid_argument if the dtype is not floating-point or complex.

Sourceval trace : ?out:('a, 'b) t -> ?offset:int -> ('a, 'b) t -> ('a, 'b) t

trace ?out ?offset t is the sum along the offset-th diagonal. offset defaults to 0. out defaults to a fresh allocation.

Raises Invalid_argument if t has fewer than 2 dimensions.

See also diagonal.

Solving

Sourceval solve : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

solve a b is x such that a @@ x = b.

Raises Invalid_argument if a is singular or the dtype is not floating-point or complex.

See also lstsq, inv.

Sourceval lstsq : ?rcond:float -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t * ('a, 'b) t * int * (float, float64_elt) t

lstsq ?rcond a b is (x, residuals, rank, sv) — the least-squares solution to a @@ x ≈ b. rcond defaults to machine precision.

Raises Invalid_argument if the dtype is not floating-point or complex.

See also solve.

Sourceval inv : ('a, 'b) t -> ('a, 'b) t

inv a is the inverse of square matrix a.

Raises Invalid_argument if a is singular, not square, or the dtype is not floating-point or complex.

See also pinv, solve.

Sourceval pinv : ?rtol:float -> ?hermitian:bool -> ('a, 'b) t -> ('a, 'b) t

pinv ?rtol ?hermitian a is the Moore–Penrose pseudoinverse of a. Handles non-square and singular matrices. hermitian defaults to false.

Raises Invalid_argument if the dtype is not floating-point or complex.

See also inv.

Sourceval tensorsolve : ?axes:int list -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

tensorsolve ?axes a b solves the tensor equation tensordot a x axes = b for x.

Raises Invalid_argument if shapes are incompatible or the dtype is not floating-point or complex.

Sourceval tensorinv : ?ind:int -> ('a, 'b) t -> ('a, 'b) t

tensorinv ?ind a is the tensor inverse such that tensordot a (tensorinv a) ind is the identity. ind defaults to 2.

Raises Invalid_argument if the result is not square in the specified dimensions or the dtype is not floating-point or complex.

Fourier transforms

Sourcetype fft_norm = [
  1. | `Backward
  2. | `Forward
  3. | `Ortho
]

FFT normalisation mode.

  • `Backward — normalise by 1/n on the inverse (default).
  • `Forward — normalise by 1/n on the forward.
  • `Ortho — normalise by 1/√n on both.
Sourceval fft : ?out:(Complex.t, 'a) t -> ?axis:int -> ?n:int -> ?norm:fft_norm -> (Complex.t, 'a) t -> (Complex.t, 'a) t

fft ?out ?axis ?n ?norm x is the 1-D discrete Fourier transform along axis. axis defaults to -1. n truncates or zero-pads the input. norm defaults to `Backward. out defaults to a fresh allocation.

See also ifft, rfft.

Sourceval ifft : ?out:(Complex.t, 'a) t -> ?axis:int -> ?n:int -> ?norm:fft_norm -> (Complex.t, 'a) t -> (Complex.t, 'a) t

ifft ?out ?axis ?n ?norm x is the inverse of fft. out defaults to a fresh allocation.

See also fft, irfft.

Sourceval fft2 : ?out:(Complex.t, 'a) t -> ?axes:int list -> ?s:int list -> ?norm:fft_norm -> (Complex.t, 'a) t -> (Complex.t, 'a) t

fft2 ?out ?axes ?s ?norm x is the 2-D FFT. axes defaults to the last two. out defaults to a fresh allocation.

Raises Invalid_argument if the input has fewer than 2 dimensions.

See also ifft2, fft.

Sourceval ifft2 : ?out:(Complex.t, 'a) t -> ?axes:int list -> ?s:int list -> ?norm:fft_norm -> (Complex.t, 'a) t -> (Complex.t, 'a) t

ifft2 ?out ?axes ?s ?norm x is the inverse of fft2. out defaults to a fresh allocation.

Sourceval fftn : ?out:(Complex.t, 'a) t -> ?axes:int list -> ?s:int list -> ?norm:fft_norm -> (Complex.t, 'a) t -> (Complex.t, 'a) t

fftn ?out ?axes ?s ?norm x is the N-D FFT. axes defaults to all. out defaults to a fresh allocation.

See also ifftn.

Sourceval ifftn : ?out:(Complex.t, 'a) t -> ?axes:int list -> ?s:int list -> ?norm:fft_norm -> (Complex.t, 'a) t -> (Complex.t, 'a) t

ifftn ?out ?axes ?s ?norm x is the inverse of fftn. out defaults to a fresh allocation.

Sourceval rfft : ?out:(Complex.t, complex64_elt) t -> ?axis:int -> ?n:int -> ?norm:fft_norm -> (float, 'a) t -> (Complex.t, complex64_elt) t

rfft ?out ?axis ?n ?norm x is the 1-D FFT of real input. Returns only the non-redundant positive frequencies; the output size along the transformed axis is n/2 + 1. out defaults to a fresh allocation.

  # create float64 [| 4 |] [| 0.; 1.; 2.; 3. |]
    |> rfft |> shape
  - : int array = [|3|]

See also irfft, fft.

Sourceval irfft : ?out:(float, float64_elt) t -> ?axis:int -> ?n:int -> ?norm:fft_norm -> (Complex.t, 'a) t -> (float, float64_elt) t

irfft ?out ?axis ?n ?norm x is the inverse of rfft, producing real output. Assumes Hermitian symmetry. out defaults to a fresh allocation.

See also rfft.

Sourceval rfft2 : ?out:(Complex.t, complex64_elt) t -> ?axes:int list -> ?s:int list -> ?norm:fft_norm -> (float, 'a) t -> (Complex.t, complex64_elt) t

rfft2 ?out ?axes ?s ?norm x is the 2-D FFT of real input. out defaults to a fresh allocation.

See also irfft2, rfft.

Sourceval irfft2 : ?out:(float, float64_elt) t -> ?axes:int list -> ?s:int list -> ?norm:fft_norm -> (Complex.t, 'a) t -> (float, float64_elt) t

irfft2 ?out ?axes ?s ?norm x is the inverse of rfft2. out defaults to a fresh allocation.

Sourceval rfftn : ?out:(Complex.t, complex64_elt) t -> ?axes:int list -> ?s:int list -> ?norm:fft_norm -> (float, 'a) t -> (Complex.t, complex64_elt) t

rfftn ?out ?axes ?s ?norm x is the N-D FFT of real input. out defaults to a fresh allocation.

See also irfftn, rfft.

Sourceval irfftn : ?out:(float, float64_elt) t -> ?axes:int list -> ?s:int list -> ?norm:fft_norm -> (Complex.t, 'a) t -> (float, float64_elt) t

irfftn ?out ?axes ?s ?norm x is the inverse of rfftn. out defaults to a fresh allocation.

Sourceval hfft : ?axis:int -> ?n:int -> ?norm:fft_norm -> (Complex.t, 'a) t -> (float, float64_elt) t

hfft ?axis ?n ?norm x is the FFT of a signal with Hermitian symmetry, producing real output.

Sourceval ihfft : ?axis:int -> ?n:int -> ?norm:fft_norm -> (float, 'a) t -> (Complex.t, complex64_elt) t

ihfft ?axis ?n ?norm x is the inverse of hfft.

Sourceval fftfreq : ?d:float -> int -> (float, float64_elt) t

fftfreq ?d n is the DFT sample frequencies for window length n and sample spacing d (default 1.0).

  # fftfreq 4
  - : (float, float64_elt) t = [0, 0.25, -0.5, -0.25]

See also rfftfreq.

Sourceval rfftfreq : ?d:float -> int -> (float, float64_elt) t

rfftfreq ?d n is the positive DFT sample frequencies: [0, 1, …, n/2] / (d * n).

See also fftfreq.

Sourceval fftshift : ?axes:int list -> ('a, 'b) t -> ('a, 'b) t

fftshift ?axes t shifts the zero-frequency component to the centre. axes defaults to all.

  # fftfreq 5 |> fftshift
  - : (float, float64_elt) t = float64 [5] [-0.4, -0.2, ..., 0.2, 0.4]

See also ifftshift.

Sourceval ifftshift : ?axes:int list -> ('a, 'b) t -> ('a, 'b) t

ifftshift ?axes t is the inverse of fftshift.

Activation functions

Sourceval relu : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

relu ?out t is max(0, t) element-wise. out defaults to a fresh allocation.

  # create float32 [| 5 |]
      [| -2.; -1.; 0.; 1.; 2. |]
    |> relu
  - : (float, float32_elt) t = float32 [5] [0, 0, ..., 1, 2]
Sourceval sigmoid : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

sigmoid ?out t is 1 / (1 + exp(-t)) element-wise. Output in (0, 1). out defaults to a fresh allocation.

  # sigmoid (scalar float32 0.) |> item []
  - : float = 0.5
Sourceval softmax : ?out:('a, 'b) t -> ?axes:int list -> ?scale:float -> ('a, 'b) t -> ('a, 'b) t

softmax ?out ?axes ?scale t is the softmax normalisation exp(scale * (t - max t)) / Σ exp(scale * (t - max t)). axes defaults to [-1]. scale defaults to 1.0. Output sums to 1 along the specified axes. out defaults to a fresh allocation.

  # create float32 [| 3 |] [| 1.; 2.; 3. |]
    |> softmax |> sum |> item []
  - : float = 1.

See also log_softmax.

Sourceval log_softmax : ?out:('a, 'b) t -> ?axes:int list -> ?scale:float -> ('a, 'b) t -> ('a, 'b) t

log_softmax ?out ?axes ?scale t is the natural logarithm of softmax. Same defaults as softmax. out defaults to a fresh allocation.

See also softmax, logsumexp.

Sourceval logsumexp : ?out:('a, 'b) t -> ?axes:int list -> ?keepdims:bool -> ('a, 'b) t -> ('a, 'b) t

logsumexp ?out ?axes ?keepdims t is log(Σ exp(t)) computed in a numerically stable way. axes defaults to all. keepdims defaults to false. out defaults to a fresh allocation.

See also logmeanexp, log_softmax.

Sourceval logmeanexp : ?out:('a, 'b) t -> ?axes:int list -> ?keepdims:bool -> ('a, 'b) t -> ('a, 'b) t

logmeanexp ?out ?axes ?keepdims t is log(mean(exp(t))): logsumexp minus log N. axes defaults to all. keepdims defaults to false. out defaults to a fresh allocation.

See also logsumexp.

Sourceval standardize : ?out:('a, 'b) t -> ?axes:int list -> ?mean:('a, 'b) t -> ?variance:('a, 'b) t -> ?epsilon:float -> ('a, 'b) t -> ('a, 'b) t

standardize ?out ?axes ?mean ?variance ?epsilon t is (t - mean) / sqrt(variance + epsilon). When mean or variance are omitted, they are computed along axes (default all). epsilon defaults to 1e-5. out defaults to a fresh allocation.

Sourceval erf : ?out:('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

erf ?out t is the error function erf(x) = (2/√π) ∫₀ˣ e^{-u²} du. out defaults to a fresh allocation.

  # erf (scalar float32 0.) |> item []
  - : float = 0.

Sliding windows

Patches

Sourceval extract_patches : kernel_size:int array -> stride:int array -> dilation:int array -> padding:(int * int) array -> ('a, 'b) t -> ('a, 'b) t

extract_patches ~kernel_size ~stride ~dilation ~padding t extracts sliding windows from the last K spatial dimensions where K = Array.length kernel_size.

Input: [leading…; spatial…]. Output: [leading…; prod(kernel_size); L].

  # arange_f float32 0. 16. 1.
    |> reshape [| 1; 1; 4; 4 |]
    |> extract_patches
         ~kernel_size:[| 2; 2 |]
         ~stride:[| 1; 1 |]
         ~dilation:[| 1; 1 |]
         ~padding:[| (0, 0); (0, 0) |]
    |> shape
  - : int array = [|1; 1; 4; 9|]

See also combine_patches.

Sourceval combine_patches : output_size:int array -> kernel_size:int array -> stride:int array -> dilation:int array -> padding:(int * int) array -> ('a, 'b) t -> ('a, 'b) t

combine_patches ~output_size ~kernel_size ~stride ~dilation ~padding t is the inverse of extract_patches. Overlapping values are summed.

See also extract_patches.

Cross-correlation and convolution

Sourceval correlate : ?padding:[ `Full | `Same | `Valid ] -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

correlate ?padding x kernel is the N-D cross-correlation (no kernel flip). Spatial dimensions K = ndim kernel. Leading dimensions of x beyond K are batch dimensions. padding defaults to `Valid.

See also convolve.

Sourceval convolve : ?padding:[ `Full | `Same | `Valid ] -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t

convolve ?padding x kernel is like correlate but flips the kernel along all spatial axes before correlating.

See also correlate.

Filters

Sourceval maximum_filter : kernel_size:int array -> ?stride:int array -> ('a, 'b) t -> ('a, 'b) t

maximum_filter ~kernel_size ?stride t is the sliding-window maximum over the last K dimensions. stride defaults to kernel_size.

See also minimum_filter, uniform_filter.

Sourceval minimum_filter : kernel_size:int array -> ?stride:int array -> ('a, 'b) t -> ('a, 'b) t

minimum_filter ~kernel_size ?stride t is the sliding-window minimum over the last K dimensions. stride defaults to kernel_size.

See also maximum_filter.

Sourceval uniform_filter : kernel_size:int array -> ?stride:int array -> (float, 'b) t -> (float, 'b) t

uniform_filter ~kernel_size ?stride t is the sliding-window mean over the last K dimensions. stride defaults to kernel_size.

See also maximum_filter, minimum_filter.

Iteration

Sourceval map_item : ('a -> 'a) -> ('a, 'b) t -> ('a, 'b) t

map_item f t applies f to each scalar element of t and returns a fresh tensor of the results.

Sourceval iter_item : ('a -> unit) -> ('a, 'b) t -> unit

iter_item f t applies f to each scalar element of t for its side effects.

Sourceval fold_item : ('a -> 'b -> 'a) -> 'a -> ('b, 'c) t -> 'a

fold_item f init t folds f over the scalar elements of t in row-major order, starting with init.

Sourceval map : (('a, 'b) t -> ('a, 'b) t) -> ('a, 'b) t -> ('a, 'b) t

map f t applies tensor function f to each element of t, presented as a scalar tensor.

See also map_item.

Sourceval iter : (('a, 'b) t -> unit) -> ('a, 'b) t -> unit

iter f t applies tensor function f to each element of t, presented as a scalar tensor.

See also iter_item.

Sourceval fold : ('a -> ('b, 'c) t -> 'a) -> 'a -> ('b, 'c) t -> 'a

fold f init t folds tensor function f over the elements of t, each presented as a scalar tensor.

See also fold_item.

Formatting

Sourceval pp_data : Format.formatter -> ('a, 'b) t -> unit

pp_data fmt t formats the data of t.

Sourceval format_to_string : (Format.formatter -> 'a -> unit) -> 'a -> string

format_to_string pp x is the string produced by pp.

Sourceval print_with_formatter : (Format.formatter -> 'a -> unit) -> 'a -> unit

print_with_formatter pp x prints x to stdout using pp.

Sourceval data_to_string : ('a, 'b) t -> string

data_to_string t is the data of t as a string.

Sourceval print_data : ('a, 'b) t -> unit

print_data t prints the data of t to stdout.

Sourceval pp_dtype : Format.formatter -> ('a, 'b) dtype -> unit

pp_dtype fmt dt formats dt.

Sourceval dtype_to_string : ('a, 'b) dtype -> string

dtype_to_string dt is dt as a string.

Sourceval shape_to_string : int array -> string

shape_to_string s formats s as "[2x3x4]".

Sourceval pp_shape : Format.formatter -> int array -> unit

pp_shape fmt s formats shape s.

Sourceval pp : Format.formatter -> ('a, 'b) t -> unit

pp fmt t formats t for debugging (dtype, shape, and data).

Sourceval print : ('a, 'b) t -> unit

print t prints t to stdout.

Sourceval to_string : ('a, 'b) t -> string

to_string t is t formatted as a string (dtype, shape, and data).