N-dimensional array operations for OCaml.
This module provides NumPy-style tensor operations. Tensors are immutable views over mutable buffers, supporting broadcasting, slicing, and efficient memory layout transformations.
Type System
The type ('a, 'b) t represents a tensor where 'a is the OCaml type of elements and 'b is the bigarray element type. For example, (float, float32_elt) t is a tensor of 32-bit floats.
Broadcasting
Operations automatically broadcast compatible shapes: each dimension must be equal or one of them must be 1. Shape |3; 1; 5| broadcasts with |1; 4; 5| to |3; 4; 5|.
Memory Layout
Tensors can be C-contiguous or strided. Operations return views when possible (O(1)), otherwise copy (O(n)). Use is_contiguous to check layout and contiguous to ensure contiguity.
Type Definitions
('a, 'b) t is a tensor with OCaml type 'a and bigarray type 'b.
Sourcetype ('a, 'b) dtype = ('a, 'b) Nx_core.Dtype.t = | Float16 : (float, float16_elt) dtype| Float32 : (float, float32_elt) dtype| Float64 : (float, float64_elt) dtype| Int8 : (int, int8_elt) dtype| UInt8 : (int, uint8_elt) dtype| Int16 : (int, int16_elt) dtype| UInt16 : (int, uint16_elt) dtype| Int32 : (int32, int32_elt) dtype| Int64 : (int64, int64_elt) dtype| Int : (int, int_elt) dtype| NativeInt : (nativeint, nativeint_elt) dtype| Complex32 : (Complex.t, complex32_elt) dtype| Complex64 : (Complex.t, complex64_elt) dtype| BFloat16 : (float, bfloat16_elt) dtype| Bool : (bool, bool_elt) dtype| Int4 : (int, int4_elt) dtype| UInt4 : (int, uint4_elt) dtype| Float8_e4m3 : (float, float8_e4m3_elt) dtype| Float8_e5m2 : (float, float8_e5m2_elt) dtype| Complex16 : (Complex.t, complex16_elt) dtype| QInt8 : (int, qint8_elt) dtype| QUInt8 : (int, quint8_elt) dtypeData type specification. Links OCaml types to bigarray element types.
Sourcetype index = | I of intSingle index: I 2 selects index 2
| L of int listList of indices: L [0; 2; 5] selects indices 0, 2, 5
| R of int * intRange [start, stop): R (1, 4) selects 1, 2, 3
| Rs of int * int * intRange with step: Rs (0, 10, 2) selects 0, 2, 4, 6, 8
| AAll indices: A selects entire axis
| M of (int, uint8_elt) tBoolean mask: M mask selects where mask is true
| NNew axis: N inserts dimension of size 1
Index specification for tensor slicing
Array Properties
Functions to inspect array dimensions, memory layout, and data access.
data t returns underlying bigarray buffer.
Buffer may contain data beyond tensor bounds for strided views. Direct access requires careful index computation using strides and offset.
Sourceval shape : ('a, 'b) t -> int array shape t returns dimensions. Empty array for scalars.
dtype t returns data type.
Sourceval strides : ('a, 'b) t -> int array strides t returns byte strides for each dimension.
Sourceval stride : int -> ('a, 'b) t -> int stride i t returns byte stride for dimension i.
Sourceval dims : ('a, 'b) t -> int array dims t is synonym for shape.
Sourceval dim : int -> ('a, 'b) t -> int dim i t returns size of dimension i.
ndim t returns number of dimensions.
Sourceval itemsize : ('a, 'b) t -> int itemsize t returns bytes per element.
size t returns total number of elements.
Sourceval numel : ('a, 'b) t -> int numel t is synonym for size.
Sourceval nbytes : ('a, 'b) t -> int nbytes t returns size t * itemsize t.
Sourceval offset : ('a, 'b) t -> int offset t returns element offset in underlying buffer.
Sourceval is_c_contiguous : ('a, 'b) t -> bool is_c_contiguous t returns true if elements are contiguous in C order.
to_bigarray t converts to standard bigarray.
Always returns contiguous copy with same shape. Use for interop with libraries expecting standard bigarrays.
# let t = create float32 [| 2; 3 |] [| 1.; 2.; 3.; 4.; 5.; 6. |]
val t : (float, float32_elt) t = [[1, 2, 3],
[4, 5, 6]]
# shape (to_bigarray t |> of_bigarray)
- : int array = [|2; 3|]
to_bigarray_ext t converts to extended bigarray.
Always returns contiguous copy with same shape. Use for interop with libraries expecting extended bigarrays (e.g., with bfloat16 support).
# let t = create float32 [| 2; 3 |] [| 1.; 2.; 3.; 4.; 5.; 6. |]
val t : (float, float32_elt) t = [[1, 2, 3],
[4, 5, 6]]
# shape (to_bigarray_ext t |> of_bigarray_ext)
- : int array = [|2; 3|]
Sourceval to_array : ('a, 'b) t -> 'a array to_array t converts to OCaml array.
Flattens tensor to 1-D array in row-major (C) order. Always copies.
# let t = create int32 [| 2; 2 |] [| 1l; 2l; 3l; 4l |]
val t : (int32, int32_elt) t = [[1, 2],
[3, 4]]
# to_array t
- : int32 array = [|1l; 2l; 3l; 4l|]
Array Creation
Functions to create and initialize arrays.
Sourceval create : ('a, 'b) dtype -> int array -> 'a array -> ('a, 'b) t create dtype shape data creates tensor from array data.
Length of data must equal product of shape.
# create float32 [| 2; 3 |] [| 1.; 2.; 3.; 4.; 5.; 6. |]
- : (float, float32_elt) t = [[1, 2, 3],
[4, 5, 6]]
Sourceval init : ('a, 'b) dtype -> int array -> (int array -> 'a) -> ('a, 'b) t init dtype shape f creates tensor where element at indices i has value f i.
Function f receives array of indices for each position. Useful for creating position-dependent values.
# init int32 [| 2; 3 |] (fun i -> Int32.of_int (i.(0) + i.(1)))
- : (int32, int32_elt) t = [[0, 1, 2],
[1, 2, 3]]
# init float32 [| 3; 3 |] (fun i -> if i.(0) = i.(1) then 1. else 0.)
- : (float, float32_elt) t = [[1, 0, 0],
[0, 1, 0],
[0, 0, 1]]
empty dtype shape allocates uninitialized tensor.
Sourceval full : ('a, 'b) dtype -> int array -> 'a -> ('a, 'b) t full dtype shape value creates tensor filled with value.
# full float32 [| 2; 3 |] 3.14
- : (float, float32_elt) t = [[3.14, 3.14, 3.14],
[3.14, 3.14, 3.14]]
ones dtype shape creates tensor filled with ones.
zeros dtype shape creates tensor filled with zeros.
scalar dtype value creates scalar tensor containing value.
Sourceval empty_like : ('a, 'b) t -> ('a, 'b) t empty_like t creates uninitialized tensor with same shape and dtype as t.
Sourceval full_like : ('a, 'b) t -> 'a -> ('a, 'b) t full_like t value creates tensor shaped like t filled with value.
Sourceval ones_like : ('a, 'b) t -> ('a, 'b) t ones_like t creates tensor shaped like t filled with ones.
Sourceval zeros_like : ('a, 'b) t -> ('a, 'b) t zeros_like t creates tensor shaped like t filled with zeros.
Sourceval scalar_like : ('a, 'b) t -> 'a -> ('a, 'b) t scalar_like t value creates scalar with same dtype as t.
Sourceval eye : ?m:int -> ?k:int -> ('a, 'b) dtype -> int -> ('a, 'b) t eye ?m ?k dtype n creates matrix with ones on k-th diagonal.
Default m = n (square), k = 0 (main diagonal). Positive k shifts diagonal above main, negative below.
# eye int32 3
- : (int32, int32_elt) t = [[1, 0, 0],
[0, 1, 0],
[0, 0, 1]]
# eye ~k:1 int32 3
- : (int32, int32_elt) t = [[0, 1, 0],
[0, 0, 1],
[0, 0, 0]]
# eye ~m:2 ~k:(-1) int32 3
- : (int32, int32_elt) t = [[0, 0, 0],
[1, 0, 0]]
identity dtype n creates n×n identity matrix.
Equivalent to eye dtype n. Square matrix with ones on main diagonal, zeros elsewhere.
# identity int32 3
- : (int32, int32_elt) t = [[1, 0, 0],
[0, 1, 0],
[0, 0, 1]]
Sourceval diag : ?k:int -> ('a, 'b) t -> ('a, 'b) t diag ?k v extracts diagonal or constructs diagonal array.
If v is 1D, returns 2D array with v on the k-th diagonal. If v is 2D, returns 1D array containing the k-th diagonal. Use k > 0 for diagonals above the main diagonal, k < 0 for diagonals below.
# let x = arange int32 0 9 1 |> reshape [|3; 3|]
val x : (int32, int32_elt) t = [[0, 1, 2],
[3, 4, 5],
[6, 7, 8]]
# diag x
- : (int32, int32_elt) t = [0, 4, 8]
# diag ~k:1 x
- : (int32, int32_elt) t = [1, 5]
# let v = create int32 [|3|] [|1l; 2l; 3l|]
val v : (int32, int32_elt) t = [1, 2, 3]
# diag v
- : (int32, int32_elt) t = [[1, 0, 0],
[0, 2, 0],
[0, 0, 3]]
Sourceval arange : ('a, 'b) dtype -> int -> int -> int -> ('a, 'b) t arange dtype start stop step generates values from start to [stop).
Step must be non-zero. Result length is (stop - start) / step rounded toward zero.
# arange int32 0 10 2
- : (int32, int32_elt) t = [0, 2, 4, 6, 8]
# arange int32 5 0 (-1)
- : (int32, int32_elt) t = [5, 4, 3, 2, 1]
Sourceval arange_f : (float, 'a) dtype -> float -> float -> float -> (float, 'a) t arange_f dtype start stop step generates float values from start to [stop).
Like arange but for floating-point ranges. Handles fractional steps. Due to floating-point precision, final value may differ slightly from expected.
# arange_f float32 0. 1. 0.2
- : (float, float32_elt) t = [0, 0.2, 0.4, 0.6, 0.8]
# arange_f float32 1. 0. (-0.25)
- : (float, float32_elt) t = [1, 0.75, 0.5, 0.25]
Sourceval linspace :
('a, 'b) dtype ->
?endpoint:bool ->
float ->
float ->
int ->
('a, 'b) t linspace dtype ?endpoint start stop count generates count evenly spaced values from start to stop.
If endpoint is true (default), stop is included.
# linspace float32 ~endpoint:true 0. 10. 5
- : (float, float32_elt) t = [0, 2.5, 5, 7.5, 10]
# linspace float32 ~endpoint:false 0. 10. 5
- : (float, float32_elt) t = [0, 2, 4, 6, 8]
Sourceval logspace :
(float, 'a) dtype ->
?endpoint:bool ->
?base:float ->
float ->
float ->
int ->
(float, 'a) t logspace dtype ?endpoint ?base start_exp stop_exp count generates values evenly spaced on log scale.
Returns base ** x where x ranges from start_exp to stop_exp. Default base = 10.0.
# logspace float32 0. 2. 3
- : (float, float32_elt) t = [1, 10, 100]
# logspace float32 ~base:2.0 0. 3. 4
- : (float, float32_elt) t = [1, 2, 4, 8]
Sourceval geomspace :
(float, 'a) dtype ->
?endpoint:bool ->
float ->
float ->
int ->
(float, 'a) t geomspace dtype ?endpoint start stop count generates values evenly spaced on geometric (multiplicative) scale.
# geomspace float32 1. 1000. 4
- : (float, float32_elt) t = [1, 10, 100, 1000]
Sourceval meshgrid :
?indexing:[ `xy | `ij ] ->
('a, 'b) t ->
('a, 'b) t ->
('a, 'b) t * ('a, 'b) t meshgrid ?indexing x y creates coordinate grids from 1D arrays.
Returns (X, Y) where X and Y are 2D arrays representing grid coordinates.
`xy (default): Cartesian indexing - X changes along columns, Y changes along rows`ij: Matrix indexing - X changes along rows, Y changes along columns
# let x = linspace float32 0. 2. 3 in
let y = linspace float32 0. 1. 2 in
meshgrid x y
- : (float, float32_elt) t * (float, float32_elt) t =
([[0, 1, 2],
[0, 1, 2]], [[0, 0, 0],
[1, 1, 1]])
Sourceval tril : ?k:int -> ('a, 'b) t -> ('a, 'b) t tril ?k x returns lower triangular part of matrix.
Elements above the k-th diagonal are zeroed.
k = 0 (default): main diagonalk > 0: include k diagonals above maink < 0: exclude |k| diagonals below main
Sourceval triu : ?k:int -> ('a, 'b) t -> ('a, 'b) t triu ?k x returns upper triangular part of matrix.
Elements below the k-th diagonal are zeroed.
k = 0 (default): main diagonalk > 0: exclude k diagonals above maink < 0: include |k| diagonals below main
of_bigarray ba creates tensor from standard bigarray.
Zero-copy when bigarray is contiguous. Creates view sharing same memory. Modifications to either affect both.
# let t = zeros float32 [| 2; 3 |] in
t
- : (float, float32_elt) t = [[0, 0, 0],
[0, 0, 0]]
of_bigarray_ext ba creates tensor from extended bigarray.
Zero-copy when bigarray is contiguous. Creates view sharing same memory. Modifications to either affect both. Supports extended types like bfloat16.
# let ba = Bigarray_ext.Genarray.create Bigarray_ext.bfloat16
Bigarray_ext.c_layout [| 2; 3 |] in
let t = of_bigarray_ext ba in
shape t
- : int array = [|2; 3|]
Random Number Generation
Functions to generate arrays with random values.
Sourceval rand : ('a, 'b) dtype -> ?seed:int -> int array -> ('a, 'b) t rand dtype ?seed shape generates uniform random values in [0, 1).
Only supports float dtypes. Same seed produces same sequence.
Sourceval randn : ('a, 'b) dtype -> ?seed:int -> int array -> ('a, 'b) t randn dtype ?seed shape generates standard normal random values.
Mean 0, variance 1. Uses Box-Muller transform for efficiency. Only supports float dtypes. Same seed produces same sequence.
Sourceval randint :
('a, 'b) dtype ->
?seed:int ->
?high:int ->
int array ->
int ->
('a, 'b) t randint dtype ?seed ?high shape low generates integers in [low, high).
Uniform distribution over range. Default high = 10. Note: high is exclusive (NumPy convention).
Shape Manipulation
Functions to reshape, transpose, and rearrange arrays.
Sourceval reshape : int array -> ('a, 'b) t -> ('a, 'b) t reshape shape t returns view with new shape.
At most one dimension can be -1 (inferred from total elements). Product of dimensions must match total elements. Returns view when possible (O(1)), copies if tensor is not contiguous and cannot be viewed.
# let t = create int32 [|2; 3|] [|1l; 2l; 3l; 4l; 5l; 6l|] in
reshape [|6|] t
- : (int32, int32_elt) t = [1, 2, 3, 4, 5, 6]
# let t = create int32 [|6|] [|1l; 2l; 3l; 4l; 5l; 6l|] in
reshape [|3; -1|] t
- : (int32, int32_elt) t = [[1, 2],
[3, 4],
[5, 6]]
Sourceval broadcast_to : int array -> ('a, 'b) t -> ('a, 'b) t broadcast_to shape t broadcasts tensor to target shape.
Shapes must be broadcast-compatible: dimensions align from right, each must be equal or source must be 1. Returns view (no copy) with zero strides for broadcast dimensions.
# let t = create int32 [|1; 3|] [|1l; 2l; 3l|] in
broadcast_to [|3; 3|] t
- : (int32, int32_elt) t = [[1, 2, 3],
[1, 2, 3],
[1, 2, 3]]
# let t = ones float32 [|3; 1|] in
shape (broadcast_to [|2; 3; 4|] t)
- : int array = [|2; 3; 4|]
Sourceval broadcasted :
?reverse:bool ->
('a, 'b) t ->
('a, 'b) t ->
('a, 'b) t * ('a, 'b) t broadcasted ?reverse t1 t2 broadcasts tensors to common shape.
Returns views of both tensors broadcast to compatible shape. If reverse is true, returns (t2', t1') instead of (t1', t2'). Useful before element-wise operations.
# let t1 = ones float32 [|3;1|] in
let t2 = ones float32 [|1;5|] in
let t1', t2' = broadcasted t1 t2 in
shape t1', shape t2'
- : int array * int array = ([|3; 5|], [|3; 5|])
Sourceval expand : int array -> ('a, 'b) t -> ('a, 'b) t expand shape t broadcasts tensor where -1 keeps original dimension.
Like broadcast_to but -1 preserves existing dimension size. Adds dimensions on left if needed.
# let t = ones float32 [|1; 4; 1|] in
shape (expand [|3; -1; 5|] t)
- : int array = [|3; 4; 5|]
# let t = ones float32 [|5; 5|] in
shape (expand [|-1; -1|] t)
- : int array = [|5; 5|]
Sourceval flatten : ?start_dim:int -> ?end_dim:int -> ('a, 'b) t -> ('a, 'b) t flatten ?start_dim ?end_dim t collapses dimensions into single dimension.
Default start_dim = 0, end_dim = -1 (last). Negative indices count from end. Dimensions start_dim through end_dim inclusive are flattened.
# flatten (zeros float32 [| 2; 3; 4 |]) |> shape
- : int array = [|24|]
# flatten ~start_dim:1 ~end_dim:2 (zeros float32 [| 2; 3; 4; 5 |]) |> shape
- : int array = [|2; 12; 5|]
Sourceval unflatten : int -> int array -> ('a, 'b) t -> ('a, 'b) t unflatten dim sizes t expands dimension dim into multiple dimensions.
Product of sizes must equal size of dimension dim. At most one dimension can be -1 (inferred). Inverse of flatten.
# unflatten 1 [| 3; 4 |] (zeros float32 [| 2; 12; 5 |]) |> shape
- : int array = [|2; 3; 4; 5|]
# unflatten 0 [| -1; 2 |] (ones float32 [| 6; 5 |]) |> shape
- : int array = [|3; 2; 5|]
Sourceval ravel : ('a, 'b) t -> ('a, 'b) t ravel t returns contiguous 1-D view.
Equivalent to flatten t but always returns contiguous result. Use when you need both flattening and contiguity.
# let x = create int32 [| 2; 3 |] [| 1l; 2l; 3l; 4l; 5l; 6l |] in
ravel x
- : (int32, int32_elt) t = [1, 2, 3, 4, 5, 6]
# let t = transpose (ones float32 [| 3; 4 |]) in
is_c_contiguous t
- : bool = false
# let t_ravel = ravel t in
is_c_contiguous t_ravel
- : bool = true
Sourceval squeeze : ?axes:int list -> ('a, 'b) t -> ('a, 'b) t squeeze ?axes t removes dimensions of size 1.
If axes specified, only removes those dimensions. Negative indices count from end. Returns view when possible.
# squeeze (ones float32 [| 1; 3; 1; 4 |]) |> shape
- : int array = [|3; 4|]
# squeeze ~axes:[ 0; 2 ] (ones float32 [| 1; 3; 1; 4 |]) |> shape
- : int array = [|3; 4|]
# squeeze ~axes:[ -1 ] (ones float32 [| 3; 4; 1 |]) |> shape
- : int array = [|3; 4|]
Sourceval unsqueeze : ?axes:int list -> ('a, 'b) t -> ('a, 'b) t unsqueeze ?axes t inserts dimensions of size 1 at specified positions.
Axes refer to positions in result tensor. Must be in range 0, ndim.
# unsqueeze ~axes:[ 0; 2 ] (create float32 [| 3 |] [| 1.; 2.; 3. |]) |> shape
- : int array = [|1; 3; 1|]
# unsqueeze ~axes:[ 1 ] (create float32 [| 2 |] [| 5.; 6. |]) |> shape
- : int array = [|2; 1|]
Sourceval squeeze_axis : int -> ('a, 'b) t -> ('a, 'b) t squeeze_axis axis t removes dimension axis if size is 1.
Sourceval unsqueeze_axis : int -> ('a, 'b) t -> ('a, 'b) t unsqueeze_axis axis t inserts dimension of size 1 at axis.
Sourceval expand_dims : int list -> ('a, 'b) t -> ('a, 'b) t Sourceval transpose : ?axes:int list -> ('a, 'b) t -> ('a, 'b) t transpose ?axes t permutes dimensions.
Default reverses all dimensions. axes must be permutation of 0..ndim-1. Returns view (no copy) with adjusted strides.
# let x = create int32 [| 2; 3 |] [| 1l; 2l; 3l; 4l; 5l; 6l |] in
transpose x
- : (int32, int32_elt) t = [[1, 4],
[2, 5],
[3, 6]]
# transpose ~axes:[ 2; 0; 1 ] (zeros float32 [| 2; 3; 4 |]) |> shape
- : int array = [|4; 2; 3|]
# let id = transpose ~axes:[ 1; 0 ] in
id == transpose
- : bool = false
Sourceval flip : ?axes:int list -> ('a, 'b) t -> ('a, 'b) t flip ?axes t reverses order along specified dimensions.
Default flips all dimensions.
# let x = create int32 [| 2; 3 |] [| 1l; 2l; 3l; 4l; 5l; 6l |] in
flip x
- : (int32, int32_elt) t = [[6, 5, 4],
[3, 2, 1]]
# let x = create int32 [| 2; 3 |] [| 1l; 2l; 3l; 4l; 5l; 6l |] in
flip ~axes:[ 1 ] x
- : (int32, int32_elt) t = [[3, 2, 1],
[6, 5, 4]]
Sourceval as_strided :
int array ->
int array ->
offset:int ->
('a, 'b) t ->
('a, 'b) t as_strided shape strides ~offset t creates a strided view of the tensor with custom shape, strides (in elements), and offset (in elements).
This is a low-level operation that allows arbitrary memory layouts. Backends that support arbitrary strides (e.g., native) implement this as zero-copy. Other backends may need to materialize the data.
Warning: This function can create views that access out-of-bounds memory if used incorrectly. Use with caution.
# let x = create float32 [| 8 |] [| 0.; 1.; 2.; 3.; 4.; 5.; 6.; 7. |] in
as_strided [| 3; 3 |] [| 2; 1 |] ~offset:0 x
- : (float, float32_elt) t = [[0, 1, 2],
[2, 3, 4],
[4, 5, 6]]
Sourceval moveaxis : int -> int -> ('a, 'b) t -> ('a, 'b) t moveaxis src dst t moves dimension from src to dst.
Sourceval swapaxes : int -> int -> ('a, 'b) t -> ('a, 'b) t swapaxes axis1 axis2 t exchanges two dimensions.
Sourceval roll : ?axis:int -> int -> ('a, 'b) t -> ('a, 'b) t roll ?axis shift t shifts elements along axis.
Elements shifted beyond last position wrap to beginning. If axis not specified, shifts flattened tensor. Negative shift rolls backward.
# let x = create int32 [| 5 |] [| 1l; 2l; 3l; 4l; 5l |] in
roll 2 x
- : (int32, int32_elt) t = [4, 5, 1, 2, 3]
# let x = create int32 [| 2; 3 |] [| 1l; 2l; 3l; 4l; 5l; 6l |] in
roll ~axis:1 1 x
- : (int32, int32_elt) t = [[3, 1, 2],
[6, 4, 5]]
# let x = create int32 [| 2; 2 |] [| 1l; 2l; 3l; 4l |] in
roll ~axis:0 (-1) x
- : (int32, int32_elt) t = [[3, 4],
[1, 2]]
Sourceval pad : (int * int) array -> 'a -> ('a, 'b) t -> ('a, 'b) t pad padding value t pads tensor with value.
padding specifies (before, after) for each dimension. Length must match tensor dimensions. Negative padding not allowed.
# let x = create float32 [| 2; 2 |] [| 1.; 2.; 3.; 4. |] in
pad [| (1, 1); (1, 1) |] 0. x |> shape
- : int array = [|4; 4|]
Sourceval shrink : (int * int) array -> ('a, 'b) t -> ('a, 'b) t shrink ranges t extracts slice from start to stop (exclusive) for each dimension.
# let x = create int32 [| 3; 3 |] [| 1l; 2l; 3l; 4l; 5l; 6l; 7l; 8l; 9l |] in
shrink [| (1, 3); (0, 2) |] x
- : (int32, int32_elt) t = [[4, 5],
[7, 8]]
Sourceval tile : int array -> ('a, 'b) t -> ('a, 'b) t tile reps t constructs tensor by repeating t.
reps specifies repetitions per dimension. If longer than ndim, prepends dimensions. Zero repetitions create empty tensor.
# let x = create int32 [| 1; 2 |] [| 1l; 2l |] in
tile [| 2; 3 |] x
- : (int32, int32_elt) t = [[1, 2, 1, 2, 1, 2],
[1, 2, 1, 2, 1, 2]]
# let x = create int32 [| 2 |] [| 1l; 2l |] in
tile [| 2; 1; 3 |] x |> shape
- : int array = [|2; 1; 6|]
Sourceval repeat : ?axis:int -> int -> ('a, 'b) t -> ('a, 'b) t repeat ?axis count t repeats elements count times.
If axis not specified, repeats flattened tensor.
# let x = create int32 [| 3 |] [| 1l; 2l; 3l |] in
repeat 2 x
- : (int32, int32_elt) t = [1, 1, 2, 2, 3, 3]
# let x = create int32 [| 1; 2 |] [| 1l; 2l |] in
repeat ~axis:0 3 x
- : (int32, int32_elt) t = [[1, 2],
[1, 2],
[1, 2]]
Array Combination and Splitting
Functions to join and split arrays.
Sourceval concatenate : ?axis:int -> ('a, 'b) t list -> ('a, 'b) t concatenate ?axis ts joins tensors along existing axis.
All tensors must have same shape except on concatenation axis. If axis not specified, flattens all tensors then concatenates. Returns contiguous result.
# let x1 = create int32 [| 2; 2 |] [| 1l; 2l; 3l; 4l |] in
let x2 = create int32 [| 1; 2 |] [| 5l; 6l |] in
concatenate ~axis:0 [x1; x2]
- : (int32, int32_elt) t = [[1, 2],
[3, 4],
[5, 6]]
# let x1 = create int32 [| 2; 2 |] [| 1l; 2l; 3l; 4l |] in
let x2 = create int32 [| 1; 2 |] [| 5l; 6l |] in
concatenate [x1; x2]
- : (int32, int32_elt) t = [1, 2, 3, 4, 5, 6]
Sourceval stack : ?axis:int -> ('a, 'b) t list -> ('a, 'b) t stack ?axis ts joins tensors along new axis.
All tensors must have identical shape. Result rank is input rank + 1. Default axis=0. Negative axis counts from end of result shape.
# let x1 = create int32 [| 2 |] [| 1l; 2l |] in
let x2 = create int32 [| 2 |] [| 3l; 4l |] in
stack [x1; x2]
- : (int32, int32_elt) t = [[1, 2],
[3, 4]]
# let x1 = create int32 [| 2 |] [| 1l; 2l |] in
let x2 = create int32 [| 2 |] [| 3l; 4l |] in
stack ~axis:1 [x1; x2]
- : (int32, int32_elt) t = [[1, 3],
[2, 4]]
# stack ~axis:(-1) [ones float32 [| 2; 3 |]; zeros float32 [| 2; 3 |]] |> shape
- : int array = [|2; 3; 2|]
Sourceval vstack : ('a, 'b) t list -> ('a, 'b) t vstack ts stacks tensors vertically (row-wise).
1-D tensors are treated as row vectors (shape 1;n). Higher-D tensors concatenate along axis 0. All tensors must have same shape except possibly first dimension.
# let x1 = create int32 [| 3 |] [| 1l; 2l; 3l |] in
let x2 = create int32 [| 3 |] [| 4l; 5l; 6l |] in
vstack [x1; x2]
- : (int32, int32_elt) t = [[1, 2, 3],
[4, 5, 6]]
# let x1 = create int32 [| 1; 2 |] [| 1l; 2l |] in
let x2 = create int32 [| 2; 2 |] [| 3l; 4l; 5l; 6l |] in
vstack [x1; x2]
- : (int32, int32_elt) t = [[1, 2],
[3, 4],
[5, 6]]
Sourceval hstack : ('a, 'b) t list -> ('a, 'b) t hstack ts stacks tensors horizontally (column-wise).
1-D tensors concatenate directly. Higher-D tensors concatenate along axis 1. For 1-D arrays of different lengths, use vstack to make 2-D first.
# let x1 = create int32 [| 3 |] [| 1l; 2l; 3l |] in
let x2 = create int32 [| 3 |] [| 4l; 5l; 6l |] in
hstack [x1; x2]
- : (int32, int32_elt) t = [1, 2, 3, 4, 5, 6]
# let x1 = create int32 [| 2; 1 |] [| 1l; 2l |] in
let x2 = create int32 [| 2; 1 |] [| 3l; 4l |] in
hstack [x1; x2]
- : (int32, int32_elt) t = [[1, 3],
[2, 4]]
# let x1 = create int32 [| 2; 2 |] [| 1l; 2l; 3l; 4l |] in
let x2 = create int32 [| 2; 1 |] [| 5l; 6l |] in
hstack [x1; x2]
- : (int32, int32_elt) t = [[1, 2, 5],
[3, 4, 6]]
Sourceval dstack : ('a, 'b) t list -> ('a, 'b) t dstack ts stacks tensors depth-wise (along third axis).
Tensors are reshaped to at least 3-D before concatenation:
- 1-D shape
n → 1;n;1 - 2-D shape
m;n → m;n;1 - 3-D+ unchanged
# let x1 = create int32 [| 2 |] [| 1l; 2l |] in
let x2 = create int32 [| 2 |] [| 3l; 4l |] in
dstack [x1; x2]
- : (int32, int32_elt) t = [[[1, 3],
[2, 4]]]
# let x1 = create int32 [| 2; 2 |] [| 1l; 2l; 3l; 4l |] in
let x2 = create int32 [| 2; 2 |] [| 5l; 6l; 7l; 8l |] in
dstack [x1; x2]
- : (int32, int32_elt) t = [[[1, 5],
[2, 6]],
[[3, 7],
[4, 8]]]
Sourceval broadcast_arrays : ('a, 'b) t list -> ('a, 'b) t list broadcast_arrays ts broadcasts all tensors to common shape.
Finds the common broadcast shape and returns list of views with that shape. Broadcasting rules: dimensions align right, each must be 1 or equal.
# let x1 = ones float32 [| 3; 1 |] in
let x2 = ones float32 [| 1; 5 |] in
broadcast_arrays [x1; x2] |> List.map shape
- : int array list = [[|3; 5|]; [|3; 5|]]
# let x1 = scalar float32 5. in
let x2 = ones float32 [| 2; 3; 4 |] in
broadcast_arrays [x1; x2] |> List.map shape
- : int array list = [[|2; 3; 4|]; [|2; 3; 4|]]
Sourceval array_split :
axis:int ->
[< `Count of int | `Indices of int list ] ->
('a, 'b) t ->
('a, 'b) t list array_split ~axis sections t splits tensor into multiple parts.
`Count n divides into n parts as evenly as possible. Extra elements go to first parts. `Indices [i1;i2;...] splits at indices creating start:i1, i1:i2, i2:end.
# let x = create int32 [| 5 |] [| 1l; 2l; 3l; 4l; 5l |] in
array_split ~axis:0 (`Count 3) x
- : (int32, int32_elt) t list = [[1, 2]; [3, 4]; [5]]
# let x = create int32 [| 6 |] [| 1l; 2l; 3l; 4l; 5l; 6l |] in
array_split ~axis:0 (`Indices [ 2; 4 ]) x
- : (int32, int32_elt) t list = [[1, 2]; [3, 4]; [5, 6]]
Sourceval split : axis:int -> int -> ('a, 'b) t -> ('a, 'b) t list split ~axis sections t splits into equal parts.
# let x = create int32 [| 4; 2 |] [| 1l; 2l; 3l; 4l; 5l; 6l; 7l; 8l |] in
split ~axis:0 2 x
- : (int32, int32_elt) t list = [[[1, 2],
[3, 4]]; [[5, 6],
[7, 8]]]
Type Conversion and Copying
Functions to convert between types and create copies.
cast dtype t converts elements to new dtype.
Returns copy with same values in new type.
# let x = create float32 [| 3 |] [| 1.5; 2.7; 3.1 |] in
cast int32 x
- : (int32, int32_elt) t = [1, 2, 3]
astype dtype t is synonym for cast.
Sourceval contiguous : ('a, 'b) t -> ('a, 'b) t contiguous t returns C-contiguous tensor.
Returns t unchanged if already contiguous (O(1)), otherwise creates contiguous copy (O(n)). Use before operations requiring direct memory access.
# let t = transpose (ones float32 [| 3; 4 |]) in
is_c_contiguous (contiguous t)
- : bool = true
Sourceval copy : ('a, 'b) t -> ('a, 'b) t copy t returns deep copy.
Always allocates new memory and copies data. Result is contiguous.
# let x = create float32 [| 3 |] [| 1.; 2.; 3. |] in
let y = copy x in
set_item [ 0 ] 999. y;
x, y
- : (float, float32_elt) t * (float, float32_elt) t =
([1, 2, 3], [999, 2, 3])
Sourceval blit : ('a, 'b) t -> ('a, 'b) t -> unit blit src dst copies src into dst.
Shapes must match exactly. Handles broadcasting internally. Modifies dst in-place.
let dst = zeros float32 [| 3; 3 |] in
blit (ones float32 [| 3; 3 |]) dst
(* dst now contains all 1s *)
Sourceval fill : 'a -> ('a, 'b) t -> ('a, 'b) t fill value t sets all elements to value.
Modifies t in-place and returns it for chaining.
# let x = zeros float32 [| 2; 3 |] in
let y = fill 5. x in
y == x
- : bool = true
Element Access and Slicing
Functions to access and modify array elements.
Sourceval get : int list -> ('a, 'b) t -> ('a, 'b) t get indices t returns subtensor at indices.
Indexes from outermost dimension. Returns scalar tensor if all dimensions indexed, otherwise returns view of remaining dimensions.
# let x = create int32 [| 2; 2; 2 |] [| 0l; 1l; 2l; 3l; 4l; 5l; 6l; 7l |] in
get [ 1; 1; 1 ] x
- : (int32, int32_elt) t = 7
# let x = create int32 [| 2; 3 |] [| 1l; 2l; 3l; 4l; 5l; 6l |] in
get [ 1 ] x
- : (int32, int32_elt) t = [4, 5, 6]
Sourceval set : int list -> ('a, 'b) t -> ('a, 'b) t -> unit set indices t value assigns value at indices.
slice specs t extracts subtensor using advanced indexing.
Each element in specs corresponds to an axis from left to right:
I i: single index (reduces dimension; negative from end)L [i1;i2;...]: fancy indexing with list of indicesR (start, stop): range [start, stop) with step 1Rs (start, stop, step): range with stepA: full axis (default for missing specs)M mask: boolean mask (shape must match axis)N: insert new dimension of size 1
Missing specs default to A. Returns view when possible.
# let x = create int32 [| 2; 4 |] [| 1l; 2l; 3l; 4l; 5l; 6l; 7l; 8l |] in
slice [ I 1 ] x
- : (int32, int32_elt) t = [5, 6, 7, 8]
# let x = create int32 [| 5 |] [| 0l; 1l; 2l; 3l; 4l |] in
slice [ R (1, 3) ] x
- : (int32, int32_elt) t = [1, 2]
# let x = create int32 [| 3; 3 |] [| 1l; 2l; 3l; 4l; 5l; 6l; 7l; 8l; 9l |] in
slice [ R (0, 2); L [0; 2] ] x
- : (int32, int32_elt) t = [[1, 3],
[4, 6]]
# let x = create int32 [| 3; 3 |] [| 1l; 2l; 3l; 4l; 5l; 6l; 7l; 8l; 9l |] in
slice [ A; N ] x (* Add new axis at position 1 *)
- : (int32, int32_elt) t = [[[1, 2, 3]],
[[4, 5, 6]],
[[7, 8, 9]]]
Sourceval set_slice : index list -> ('a, 'b) t -> ('a, 'b) t -> unit set_slice specs t value assigns value to indexed region.
Like slice but modifies t in-place. Value is broadcast if needed.
Sourceval item : int list -> ('a, 'b) t -> 'a item indices t returns scalar value at indices.
Must provide indices for all dimensions.
Sourceval set_item : int list -> 'a -> ('a, 'b) t -> unit set_item indices value t sets scalar value at indices.
Must provide indices for all dimensions. Modifies tensor in-place.
Sourceval take :
?axis:int ->
?mode:[ `raise | `wrap | `clip ] ->
(int32, int32_elt) t ->
('a, 'b) t ->
('a, 'b) t take ?axis ?mode indices t takes elements from t using indices.
Equivalent to tindices in NumPy along the specified axis. If axis is None, flattens t first. indices is an integer tensor of indices to take. mode handles out-of-bounds indices: `raise (default), `wrap (modulo), `clip (clamp to bounds).
Returns a new tensor with shape based on indices and t's shape.
# let x = create int32 [| 5 |] [| 0l; 1l; 2l; 3l; 4l |] in
take (create int32 [| 3 |] [| 1l; 3l; 0l |]) x
- : (int32, int32_elt) t = [1, 3, 0]
# let x = create int32 [| 3; 3 |] [| 1l; 2l; 3l; 4l; 5l; 6l; 7l; 8l; 9l |] in
take ~axis:1 (create int32 [| 3 |] [| 0l; 2l; 1l |]) x
- : (int32, int32_elt) t = [[1, 3, 2],
[4, 6, 5],
[7, 9, 8]]
# let x = create int32 [| 5 |] [| 0l; 1l; 2l; 3l; 4l |] in
take ~mode:`clip (create int32 [| 2 |] [| -1l; 5l |]) x (* Clamps to [0,4] *)
- : (int32, int32_elt) t = [0, 4]
take_along_axis ~axis indices t takes values along the specified axis using indices.
Equivalent to NumPy's take_along_axis. indices must have the same shape as t except along the specified axis, where it matches the output size. Useful for gathering from argmax/argmin results.
# let x = create float32 [| 2; 3 |] [| 4.; 1.; 2.; 3.; 5.; 6. |] in
let indices = create int32 [| 2; 1 |] [| 1l; 0l |] in (* Per row indices *)
take_along_axis ~axis:1 indices x
- : (float, float32_elt) t = [[1],
[3]]
# let x = create float32 [| 3; 3 |] [| 1.; 2.; 3.; 4.; 5.; 6.; 7.; 8.; 9. |] in
let indices = expand_dims [ 0 ] (argmax ~axis:0 x) in (* Shape [1, 3] *)
take_along_axis ~axis:0 indices x (* Max per column *)
- : (float, float32_elt) t = [[7, 8, 9]]
Sourceval put :
?axis:int ->
indices:(int32, int32_elt) t ->
values:('a, 'b) t ->
?mode:[ `raise | `wrap | `clip ] ->
('a, 'b) t ->
unit put ?axis ~indices ~values ?mode t sets elements in t at positions specified by indices to values.
Equivalent to NumPy's put (in-place version of take for setting). If axis is None, flattens t first. indices is an integer tensor of positions to set. values must match the number of indices (broadcasted if needed). mode handles out-of-bounds indices: `raise (default), `wrap (modulo), `clip (clamp).
Modifies t in-place.
# let x = zeros int32 [| 5 |] in
put ~indices:(create int32 [| 3 |] [| 1l; 3l; 0l |])
~values:(create int32 [| 3 |] [| 10l; 20l; 30l |]) x;
x
- : (int32, int32_elt) t = [30, 10, 0, 20, 0]
# let x = create int32 [| 3; 3 |] [| 1l; 2l; 3l; 4l; 5l; 6l; 7l; 8l; 9l |] in
put ~axis:1 ~indices:(create int32 [| 3; 1 |] [| 0l; 2l; 1l |])
~values:(create int32 [| 3; 1 |] [| 10l; 20l; 30l |]) x;
x
- : (int32, int32_elt) t = [[10, 2, 3],
[4, 5, 20],
[7, 30, 9]]
# let y = zeros int32 [| 5 |] in
put ~mode:`clip ~indices:(create int32 [| 2 |] [| -1l; 5l |])
~values:(create int32 [| 2 |] [| 99l; 99l |]) y;
y (* Clamps to [0,4] *)
- : (int32, int32_elt) t = [99, 0, 0, 0, 99]
Sourceval put_along_axis :
axis:int ->
indices:(int32, int32_elt) t ->
values:('a, 'b) t ->
('a, 'b) t ->
unit put_along_axis ~axis ~indices ~values t sets values along the specified axis using indices.
Equivalent to NumPy's put_along_axis. indices must have the same shape as t except along the axis (where it matches values' size along that axis). values is broadcasted to match the selection shape. Useful for scattering to argmax/argmin positions.
Modifies t in-place.
# let x = zeros float32 [| 2; 3 |] in
let indices = create int32 [| 2; 1 |] [| 1l; 0l |] in (* Per row positions *)
put_along_axis ~axis:1 ~indices ~values:(create float32 [| 2; 1 |] [| 10.; 20. |]) x;
x
- : (float, float32_elt) t = [[0, 10, 0],
[20, 0, 0]]
# let x = create float32 [| 3; 3 |] [| 1.; 2.; 3.; 4.; 5.; 6.; 7.; 8.; 9. |] in
let indices = expand_dims [ 0 ] (argmax ~axis:0 x) in (* Shape [1, 3] *)
put_along_axis ~axis:0 ~indices ~values:(ones float32 [| 1; 3 |]) x;
x (* Set max per column to 1 *)
- : (float, float32_elt) t = [[1, 2, 3],
[4, 5, 6],
[1, 1, 1]]
Sourceval compress :
?axis:int ->
condition:(int, uint8_elt) t ->
('a, 'b) t ->
('a, 'b) t compress ?axis condition t selects elements where condition is true.
Equivalent to NumPy's compress. condition is a 1D boolean array. If axis is None, flattens t first. Otherwise, compresses along the specified axis (condition length must match t's dim along axis).
Returns a new tensor with reduced size along the axis/flattened.
# let x = create int32 [| 5 |] [| 1l; 2l; 3l; 4l; 5l |] in
compress ~condition:(create uint8 [| 5 |] [| 1; 0; 1; 0; 1 |]) x
- : (int32, int32_elt) t = [1, 3, 5]
# let x = create int32 [| 3; 3 |] [| 1l; 2l; 3l; 4l; 5l; 6l; 7l; 8l; 9l |] in
compress ~axis:0 ~condition:(create uint8 [| 3 |] [| 0; 1; 1 |]) x
- : (int32, int32_elt) t = [[4, 5, 6],
[7, 8, 9]]
extract condition t flattens and selects elements where condition is true.
Equivalent to NumPy's extract (1D compress after flatten). condition must have the same shape and size as t (element-wise).
Returns a 1D tensor with selected elements.
# let x = create int32 [| 3; 3 |] [| 1l; 2l; 3l; 4l; 5l; 6l; 7l; 8l; 9l |] in
extract ~condition:(greater_s x 5l) x
- : (int32, int32_elt) t = [6, 7, 8, 9]
nonzero t returns indices of non-zero elements.
Equivalent to NumPy's nonzero. Treats non-zero as true for bool tensors. Returns an array of 1D tensors, one per dimension, with coordinates of non-zeros.
For example, for a 2D tensor, returns | rows; cols | where rowsi, colsi is the position of the i-th non-zero.
# let x = create int32 [| 3; 3 |] [| 0l; 1l; 0l; 2l; 0l; 3l; 0l; 0l; 4l |] in
let indices = nonzero x in
indices.(0), indices.(1)
- : (int32, int32_elt) t * (int32, int32_elt) t =
([0, 1, 1, 2], [1, 0, 2, 2])
argwhere t returns indices of non-zero elements as a 2D tensor.
Equivalent to NumPy's argwhere. Each row is a coordinate dim0; dim1; ... of a non-zero element. Shape is num_nonzeros; ndim.
# let x = create int32 [| 3; 3 |] [| 0l; 1l; 0l; 2l; 0l; 3l; 0l; 0l; 4l |] in
argwhere x
- : (int32, int32_elt) t = [[0, 1],
[1, 0],
[1, 2],
[2, 2]]
Basic Arithmetic Operations
Element-wise arithmetic operations and their variants.
Sourceval add : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t add t1 t2 computes element-wise sum with broadcasting.
Sourceval add_s : ('a, 'b) t -> 'a -> ('a, 'b) t add_s t scalar adds scalar to each element.
Sourceval iadd : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t iadd target value adds value to target in-place.
Returns modified target.
Sourceval radd_s : 'a -> ('a, 'b) t -> ('a, 'b) t radd_s scalar t is add_s t scalar.
Sourceval iadd_s : ('a, 'b) t -> 'a -> ('a, 'b) t iadd_s t scalar adds scalar to t in-place.
Sourceval sub : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t sub t1 t2 computes element-wise difference with broadcasting.
Sourceval sub_s : ('a, 'b) t -> 'a -> ('a, 'b) t sub_s t scalar subtracts scalar from each element.
Sourceval rsub_s : 'a -> ('a, 'b) t -> ('a, 'b) t rsub_s scalar t computes scalar - t.
Sourceval isub : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t isub target value subtracts value from target in-place.
Sourceval isub_s : ('a, 'b) t -> 'a -> ('a, 'b) t isub_s t scalar subtracts scalar from t in-place.
Sourceval mul : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t mul t1 t2 computes element-wise product with broadcasting.
Sourceval mul_s : ('a, 'b) t -> 'a -> ('a, 'b) t mul_s t scalar multiplies each element by scalar.
Sourceval rmul_s : 'a -> ('a, 'b) t -> ('a, 'b) t rmul_s scalar t is mul_s t scalar.
Sourceval imul : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t imul target value multiplies target by value in-place.
Sourceval imul_s : ('a, 'b) t -> 'a -> ('a, 'b) t imul_s t scalar multiplies t by scalar in-place.
Sourceval div : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t div t1 t2 computes element-wise division.
True division for floats (result is float). Integer division for integers (truncates toward zero). Complex division follows standard rules.
# let x = create float32 [| 3 |] [| 7.; 8.; 9. |] in
let y = create float32 [| 3 |] [| 2.; 2.; 2. |] in
div x y
- : (float, float32_elt) t = [3.5, 4, 4.5]
# let x = create int32 [| 3 |] [| 7l; 8l; 9l |] in
let y = create int32 [| 3 |] [| 2l; 2l; 2l |] in
div x y
- : (int32, int32_elt) t = [3, 4, 4]
# let x = create int32 [| 2 |] [| -7l; 8l |] in
let y = create int32 [| 2 |] [| 2l; 2l |] in
div x y
- : (int32, int32_elt) t = [-3, 4]
Sourceval div_s : ('a, 'b) t -> 'a -> ('a, 'b) t div_s t scalar divides each element by scalar.
Sourceval rdiv_s : 'a -> ('a, 'b) t -> ('a, 'b) t rdiv_s scalar t computes scalar / t.
Sourceval idiv : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t idiv target value divides target by value in-place.
Sourceval idiv_s : ('a, 'b) t -> 'a -> ('a, 'b) t idiv_s t scalar divides t by scalar in-place.
Sourceval pow : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t pow base exponent computes element-wise power.
Sourceval pow_s : ('a, 'b) t -> 'a -> ('a, 'b) t pow_s t scalar raises each element to scalar power.
Sourceval rpow_s : 'a -> ('a, 'b) t -> ('a, 'b) t rpow_s scalar t computes scalar ** t.
Sourceval ipow : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t ipow target exponent raises target to exponent in-place.
Sourceval ipow_s : ('a, 'b) t -> 'a -> ('a, 'b) t ipow_s t scalar raises t to scalar power in-place.
Sourceval mod_ : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t mod_ t1 t2 computes element-wise modulo.
Sourceval mod_s : ('a, 'b) t -> 'a -> ('a, 'b) t mod_s t scalar computes modulo scalar for each element.
Sourceval rmod_s : 'a -> ('a, 'b) t -> ('a, 'b) t rmod_s scalar t computes scalar mod t.
Sourceval imod : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t imod target divisor computes modulo in-place.
Sourceval imod_s : ('a, 'b) t -> 'a -> ('a, 'b) t imod_s t scalar computes modulo scalar in-place.
Sourceval neg : ('a, 'b) t -> ('a, 'b) t neg t negates all elements.
Mathematical Functions
Unary mathematical operations and special functions.
Sourceval abs : ('a, 'b) t -> ('a, 'b) t abs t computes absolute value.
Sourceval sign : ('a, 'b) t -> ('a, 'b) t sign t returns -1, 0, or 1 based on sign.
For unsigned types, returns 1 for all non-zero values, 0 for zero.
# let x = create float32 [| 3 |] [| -2.; 0.; 3.5 |] in
sign x
- : (float, float32_elt) t = [-1, 0, 1]
Sourceval square : ('a, 'b) t -> ('a, 'b) t square t computes element-wise square.
Sourceval sqrt : ('a, 'b) t -> ('a, 'b) t sqrt t computes element-wise square root.
Sourceval rsqrt : ('a, 'b) t -> ('a, 'b) t rsqrt t computes reciprocal square root.
Sourceval recip : ('a, 'b) t -> ('a, 'b) t recip t computes element-wise reciprocal.
Sourceval log : (float, 'a) t -> (float, 'a) t log t computes natural logarithm.
Sourceval log2 : ('a, 'b) t -> ('a, 'b) t log2 t computes base-2 logarithm.
Sourceval exp : (float, 'a) t -> (float, 'a) t exp t computes exponential.
Sourceval exp2 : ('a, 'b) t -> ('a, 'b) t Sourceval sin : ('a, 'b) t -> ('a, 'b) t Sourceval cos : (float, 'a) t -> (float, 'a) t Sourceval tan : (float, 'a) t -> (float, 'a) t Sourceval asin : (float, 'a) t -> (float, 'a) t Sourceval acos : (float, 'a) t -> (float, 'a) t acos t computes arccosine.
Sourceval atan : (float, 'a) t -> (float, 'a) t atan t computes arctangent.
Sourceval atan2 : (float, 'a) t -> (float, 'a) t -> (float, 'a) t atan2 y x computes arctangent of y/x using signs to determine quadrant.
Returns angle in radians in range -π, π. Handles x=0 correctly.
# let y = scalar float32 1. in
let x = scalar float32 1. in
atan2 y x |> item [] |> Float.round
- : float = 1.
# let y = scalar float32 1. in
let x = scalar float32 0. in
atan2 y x |> item [] |> Float.round
- : float = 2.
# let y = scalar float32 0. in
let x = scalar float32 0. in
atan2 y x |> item []
- : float = 0.
Sourceval sinh : (float, 'a) t -> (float, 'a) t sinh t computes hyperbolic sine.
Sourceval cosh : (float, 'a) t -> (float, 'a) t cosh t computes hyperbolic cosine.
Sourceval tanh : (float, 'a) t -> (float, 'a) t tanh t computes hyperbolic tangent.
Sourceval asinh : (float, 'a) t -> (float, 'a) t asinh t computes inverse hyperbolic sine.
Sourceval acosh : (float, 'a) t -> (float, 'a) t acosh t computes inverse hyperbolic cosine.
Sourceval atanh : (float, 'a) t -> (float, 'a) t atanh t computes inverse hyperbolic tangent.
Sourceval hypot : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t hypot x y computes sqrt(x² + y²) avoiding overflow.
Uses numerically stable algorithm: max * sqrt(1 + (min/max)²).
# let x = scalar float32 3. in
let y = scalar float32 4. in
hypot x y |> item []
- : float = 5.
# let x = scalar float64 1e200 in
let y = scalar float64 1e200 in
hypot x y |> item [] < Float.infinity
- : bool = true
Sourceval trunc : ('a, 'b) t -> ('a, 'b) t trunc t rounds toward zero.
Removes fractional part. Positive values round down, negative round up.
# let x = create float32 [| 3 |] [| 2.7; -2.7; 2.0 |] in
trunc x
- : (float, float32_elt) t = [2, -2, 2]
Sourceval ceil : (float, 'a) t -> (float, 'a) t ceil t rounds up to nearest integer.
Smallest integer not less than input.
# let x = create float32 [| 4 |] [| 2.1; 2.9; -2.1; -2.9 |] in
ceil x
- : (float, float32_elt) t = [3, 3, -2, -2]
Sourceval floor : (float, 'a) t -> (float, 'a) t floor t rounds down to nearest integer.
Largest integer not greater than input.
# let x = create float32 [| 4 |] [| 2.1; 2.9; -2.1; -2.9 |] in
floor x
- : (float, float32_elt) t = [2, 2, -3, -3]
Sourceval round : (float, 'a) t -> (float, 'a) t round t rounds to nearest integer (half away from zero).
Ties round away from zero (not banker's rounding).
# let x = create float32 [| 4 |] [| 2.5; 3.5; -2.5; -3.5 |] in
round x
- : (float, float32_elt) t = [3, 4, -3, -4]
Sourceval lerp : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t lerp start end_ weight computes linear interpolation.
Returns start + weight * (end_ - start). Weight typically in 0, 1.
# let start = scalar float32 0. in
let end_ = scalar float32 10. in
let weight = scalar float32 0.3 in
lerp start end_ weight |> item []
- : float = 3.
# let start = create float32 [| 2 |] [| 1.; 2. |] in
let end_ = create float32 [| 2 |] [| 5.; 8. |] in
let weight = create float32 [| 2 |] [| 0.25; 0.5 |] in
lerp start end_ weight
- : (float, float32_elt) t = [2, 5]
Sourceval lerp_scalar_weight : ('a, 'b) t -> ('a, 'b) t -> 'a -> ('a, 'b) t lerp_scalar_weight start end_ weight interpolates with scalar weight.
Comparison and Logical Operations
Element-wise comparisons and logical operations.
cmplt t1 t2 returns 1 where t1 < t2, 0 elsewhere.
less t1 t2 is synonym for cmplt.
less_s t scalar checks if each element is less than scalar.
cmpne t1 t2 returns 1 where t1 ≠ t2, 0 elsewhere.
not_equal t1 t2 is synonym for cmpne.
not_equal_s t scalar compares each element with scalar for inequality.
cmpeq t1 t2 returns 1 where t1 = t2, 0 elsewhere.
equal t1 t2 is synonym for cmpeq.
equal_s t scalar compares each element with scalar for equality.
cmpgt t1 t2 returns 1 where t1 > t2, 0 elsewhere.
greater t1 t2 is synonym for cmpgt.
greater_s t scalar checks if each element is greater than scalar.
cmple t1 t2 returns 1 where t1 ≤ t2, 0 elsewhere.
less_equal t1 t2 is synonym for cmple.
less_equal_s t scalar checks if each element is less than or equal to scalar.
cmpge t1 t2 returns 1 where t1 ≥ t2, 0 elsewhere.
greater_equal t1 t2 is synonym for cmpge.
greater_equal_s t scalar checks if each element is greater than or equal to scalar.
array_equal t1 t2 returns scalar 1 if all elements equal, 0 otherwise.
Broadcasts inputs before comparison. Returns 0 if shapes incompatible.
# let x = create int32 [| 3 |] [| 1l; 2l; 3l |] in
let y = create int32 [| 3 |] [| 1l; 2l; 3l |] in
array_equal x y |> item []
- : int = 1
# let x = create int32 [| 2 |] [| 1l; 2l |] in
let y = create int32 [| 2 |] [| 1l; 3l |] in
array_equal x y |> item []
- : int = 0
Sourceval maximum : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t maximum t1 t2 returns element-wise maximum.
Sourceval maximum_s : ('a, 'b) t -> 'a -> ('a, 'b) t maximum_s t scalar returns maximum of each element and scalar.
Sourceval rmaximum_s : 'a -> ('a, 'b) t -> ('a, 'b) t rmaximum_s scalar t is maximum_s t scalar.
Sourceval imaximum : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t imaximum target value computes maximum in-place.
Sourceval imaximum_s : ('a, 'b) t -> 'a -> ('a, 'b) t imaximum_s t scalar computes maximum with scalar in-place.
Sourceval minimum : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t minimum t1 t2 returns element-wise minimum.
Sourceval minimum_s : ('a, 'b) t -> 'a -> ('a, 'b) t minimum_s t scalar returns minimum of each element and scalar.
Sourceval rminimum_s : 'a -> ('a, 'b) t -> ('a, 'b) t rminimum_s scalar t is minimum_s t scalar.
Sourceval iminimum : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t iminimum target value computes minimum in-place.
Sourceval iminimum_s : ('a, 'b) t -> 'a -> ('a, 'b) t iminimum_s t scalar computes minimum with scalar in-place.
Sourceval logical_and : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t logical_and t1 t2 computes element-wise AND.
Non-zero values are true.
Sourceval logical_or : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t logical_or t1 t2 computes element-wise OR.
Sourceval logical_xor : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t logical_xor t1 t2 computes element-wise XOR.
Sourceval logical_not : ('a, 'b) t -> ('a, 'b) t logical_not t computes element-wise NOT.
Returns 1 - x. Non-zero values become 0, zero becomes 1.
# let x = create int32 [| 3 |] [| 0l; 1l; 5l |] in
logical_not x
- : (int32, int32_elt) t = [1, 0, -4]
isinf t returns 1 where infinite, 0 elsewhere.
Detects both positive and negative infinity. Non-float types return all 0s.
# let x = create float32 [| 4 |] [| 1.; Float.infinity; Float.neg_infinity; Float.nan |] in
isinf x
- : (int, uint8_elt) t = [0, 1, 1, 0]
isnan t returns 1 where NaN, 0 elsewhere.
NaN is the only value that doesn't equal itself. Non-float types return all 0s.
# let x = create float32 [| 3 |] [| 1.; Float.nan; Float.infinity |] in
isnan x
- : (int, uint8_elt) t = [0, 1, 0]
isfinite t returns 1 where finite, 0 elsewhere.
Finite means not inf, -inf, or NaN. Non-float types return all 1s.
# let x = create float32 [| 4 |] [| 1.; Float.infinity; Float.nan; -0. |] in
isfinite x
- : (int, uint8_elt) t = [1, 0, 0, 1]
where cond if_true if_false selects elements based on condition.
Returns if_true where cond is non-zero, if_false elsewhere. All three inputs broadcast to common shape.
# let cond = create uint8 [| 3 |] [| 1; 0; 1 |] in
let if_true = create int32 [| 3 |] [| 2l; 3l; 4l |] in
let if_false = create int32 [| 3 |] [| 5l; 6l; 7l |] in
where cond if_true if_false
- : (int32, int32_elt) t = [2, 6, 4]
# let x = create float32 [| 4 |] [| -1.; 2.; -3.; 4. |] in
where (cmpgt x (scalar float32 0.)) x (scalar float32 0.)
- : (float, float32_elt) t = [0, 2, 0, 4]
Sourceval clamp : ?min:'a -> ?max:'a -> ('a, 'b) t -> ('a, 'b) t clamp ?min ?max t limits values to range.
Elements below min become min, above max become max.
Sourceval clip : ?min:'a -> ?max:'a -> ('a, 'b) t -> ('a, 'b) t clip ?min ?max t is synonym for clamp.
Bitwise Operations
Bitwise operations on integer arrays.
Sourceval bitwise_xor : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t bitwise_xor t1 t2 computes element-wise XOR.
Sourceval bitwise_or : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t bitwise_or t1 t2 computes element-wise OR.
Sourceval bitwise_and : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t bitwise_and t1 t2 computes element-wise AND.
Sourceval bitwise_not : ('a, 'b) t -> ('a, 'b) t bitwise_not t computes element-wise NOT.
Sourceval invert : ('a, 'b) t -> ('a, 'b) t Sourceval lshift : ('a, 'b) t -> int -> ('a, 'b) t lshift t shift left-shifts elements by shift bits.
Equivalent to multiplication by 2^shift. Overflow wraps around.
# let x = create int32 [| 3 |] [| 1l; 2l; 3l |] in
lshift x 2
- : (int32, int32_elt) t = [4, 8, 12]
Sourceval rshift : ('a, 'b) t -> int -> ('a, 'b) t rshift t shift right-shifts elements by shift bits.
Equivalent to integer division by 2^shift (rounds toward zero).
# let x = create int32 [| 3 |] [| 8l; 9l; 10l |] in
rshift x 2
- : (int32, int32_elt) t = [2, 2, 2]
Reduction Operations
Functions that reduce array dimensions.
Sourceval sum : ?axes:int list -> ?keepdims:bool -> ('a, 'b) t -> ('a, 'b) t sum ?axes ?keepdims t sums elements along specified axes.
Default sums all axes (returns scalar). If keepdims is true, retains reduced dimensions with size 1. Negative axes count from end.
# let x = create float32 [| 2; 2 |] [| 1.; 2.; 3.; 4. |] in
sum x |> item []
- : float = 10.
# let x = create float32 [| 2; 2 |] [| 1.; 2.; 3.; 4. |] in
sum ~axes:[ 0 ] x
- : (float, float32_elt) t = [4, 6]
# let x = create float32 [| 1; 2 |] [| 1.; 2. |] in
sum ~axes:[ 1 ] ~keepdims:true x
- : (float, float32_elt) t = [[3]]
# let x = create float32 [| 1; 3 |] [| 1.; 2.; 3. |] in
sum ~axes:[ -1 ] x
- : (float, float32_elt) t = [6]
Sourceval max : ?axes:int list -> ?keepdims:bool -> ('a, 'b) t -> ('a, 'b) t max ?axes ?keepdims t finds maximum along axes.
Default reduces all axes. NaN propagates (any NaN input gives NaN output).
# let x = create float32 [| 2; 3 |] [| 1.; 2.; 3.; 4.; 5.; 6. |] in
max x |> item []
- : float = 6.
# let x = create float32 [| 2; 2 |] [| 1.; 2.; 3.; 4. |] in
max ~axes:[ 0 ] x
- : (float, float32_elt) t = [3, 4]
# let x = create float32 [| 1; 2 |] [| 1.; 2. |] in
max ~axes:[ 1 ] ~keepdims:true x
- : (float, float32_elt) t = [[2]]
Sourceval min : ?axes:int list -> ?keepdims:bool -> ('a, 'b) t -> ('a, 'b) t min ?axes ?keepdims t finds minimum along axes.
Default reduces all axes. NaN propagates (any NaN input gives NaN output).
# let x = create float32 [| 2; 3 |] [| 1.; 2.; 3.; 4.; 5.; 6. |] in
min x |> item []
- : float = 1.
# let x = create float32 [| 2; 2 |] [| 1.; 2.; 3.; 4. |] in
min ~axes:[ 0 ] x
- : (float, float32_elt) t = [1, 2]
Sourceval prod : ?axes:int list -> ?keepdims:bool -> ('a, 'b) t -> ('a, 'b) t prod ?axes ?keepdims t computes product along axes.
Default multiplies all elements. Empty axes give 1.
# let x = create int32 [| 3 |] [| 2l; 3l; 4l |] in
prod x |> item []
- : int32 = 24l
# let x = create int32 [| 2; 2 |] [| 1l; 2l; 3l; 4l |] in
prod ~axes:[ 0 ] x
- : (int32, int32_elt) t = [3, 8]
Sourceval cumsum : ?axis:int -> ('a, 'b) t -> ('a, 'b) t cumsum ?axis t computes the inclusive cumulative sum. Defaults to flattening the tensor (row-major order) when axis is omitted.
Sourceval cumprod : ?axis:int -> ('a, 'b) t -> ('a, 'b) t cumprod ?axis t computes the inclusive cumulative product. Defaults to flattening the tensor when axis is omitted.
Sourceval cummax : ?axis:int -> ('a, 'b) t -> ('a, 'b) t cummax ?axis t computes the inclusive cumulative maximum. NaNs propagate for floating-point dtypes. Defaults to flattening when axis is omitted.
Sourceval cummin : ?axis:int -> ('a, 'b) t -> ('a, 'b) t cummin ?axis t computes the inclusive cumulative minimum. NaNs propagate for floating-point dtypes. Defaults to flattening when axis is omitted.
Sourceval mean : ?axes:int list -> ?keepdims:bool -> ('a, 'b) t -> ('a, 'b) t mean ?axes ?keepdims t computes arithmetic mean along axes.
Sum of elements divided by count. NaN propagates.
# let x = create float32 [| 4 |] [| 1.; 2.; 3.; 4. |] in
mean x |> item []
- : float = 2.5
# let x = create float32 [| 2; 3 |] [| 1.; 2.; 3.; 4.; 5.; 6. |] in
mean ~axes:[ 1 ] x
- : (float, float32_elt) t = [2, 5]
Sourceval var :
?axes:int list ->
?keepdims:bool ->
?ddof:int ->
('a, 'b) t ->
('a, 'b) t var ?axes ?keepdims ?ddof t computes variance along axes.
ddof is delta degrees of freedom. Default 0 (population variance). Use 1 for sample variance. Variance = E(X - E[X])² / (N - ddof).
# let x = create float32 [| 5 |] [| 1.; 2.; 3.; 4.; 5. |] in
var x |> item []
- : float = 2.
# let x = create float32 [| 5 |] [| 1.; 2.; 3.; 4.; 5. |] in
var ~ddof:1 x |> item []
- : float = 2.5
Sourceval std :
?axes:int list ->
?keepdims:bool ->
?ddof:int ->
('a, 'b) t ->
('a, 'b) t std ?axes ?keepdims ?ddof t computes standard deviation.
Square root of variance: sqrt(var(t, ddof)). See var for ddof meaning.
# let x = create float32 [| 5 |] [| 1.; 2.; 3.; 4.; 5. |] in
std x |> item [] |> Float.round
- : float = 1.
# let x = create float32 [| 5 |] [| 1.; 2.; 3.; 4.; 5. |] in
std ~ddof:1 x |> item [] |> Float.round
- : float = 2.
all ?axes ?keepdims t tests if all elements are true (non-zero).
Returns 1 if all elements along axes are non-zero, 0 otherwise.
# let x = create int32 [| 3 |] [| 1l; 2l; 3l |] in
all x |> item []
- : int = 1
# let x = create int32 [| 3 |] [| 1l; 0l; 3l |] in
all x |> item []
- : int = 0
# let x = create int32 [| 2; 2 |] [| 1l; 0l; 1l; 1l |] in
all ~axes:[ 1 ] x
- : (int, uint8_elt) t = [0, 1]
any ?axes ?keepdims t tests if any element is true (non-zero).
Returns 1 if any element along axes is non-zero, 0 if all are zero.
# let x = create int32 [| 3 |] [| 0l; 0l; 1l |] in
any x |> item []
- : int = 1
# let x = create int32 [| 3 |] [| 0l; 0l; 0l |] in
any x |> item []
- : int = 0
# let x = create int32 [| 2; 2 |] [| 0l; 0l; 0l; 1l |] in
any ~axes:[ 1 ] x
- : (int, uint8_elt) t = [0, 1]
argmax ?axis ?keepdims t finds indices of maximum values.
Returns index of first occurrence for ties. If axis not specified, operates on flattened tensor and returns scalar.
# let x = create int32 [| 5 |] [| 3l; 1l; 4l; 1l; 5l |] in
argmax x |> item []
- : int32 = 4l
# let x = create int32 [| 2; 3 |] [| 1l; 5l; 3l; 2l; 4l; 6l |] in
argmax ~axis:1 x
- : (int32, int32_elt) t = [1, 2]
argmin ?axis ?keepdims t finds indices of minimum values.
Returns index of first occurrence for ties. If axis not specified, operates on flattened tensor and returns scalar.
# let x = create int32 [| 5 |] [| 3l; 1l; 4l; 1l; 5l |] in
argmin x |> item []
- : int32 = 1l
# let x = create int32 [| 2; 3 |] [| 5l; 2l; 3l; 1l; 4l; 0l |] in
argmin ~axis:1 x
- : (int32, int32_elt) t = [1, 2]
Sorting and Searching
Functions for sorting arrays and finding indices.
Sourceval sort :
?descending:bool ->
?axis:int ->
('a, 'b) t ->
('a, 'b) t * (int32, int32_elt) t sort ?descending ?axis t sorts elements along axis.
Returns (sorted_values, indices) where indices map sorted positions to original positions. Default sorts last axis in ascending order.
Algorithm: Bitonic sort (parallel-friendly, stable)
- Pads to power of 2 with inf/-inf for correctness
- O(n log² n) comparisons, O(log² n) depth
- Stable: preserves relative order of equal elements
- First occurrence wins for duplicate values
Special values:
- NaN: sorted to end (ascending) or beginning (descending)
- inf/-inf: sorted normally
- For integers: uses max/min values for padding
# let x = create int32 [| 5 |] [| 3l; 1l; 4l; 1l; 5l |] in
sort x
- : (int32, int32_elt) t * (int32, int32_elt) t =
([1, 1, 3, 4, 5], [1, 3, 0, 2, 4])
# let x = create int32 [| 2; 2 |] [| 3l; 1l; 1l; 4l |] in
sort ~descending:true ~axis:0 x
- : (int32, int32_elt) t * (int32, int32_elt) t =
([[3, 4],
[1, 1]], [[0, 1],
[1, 0]])
# let x = create float32 [| 4 |] [| Float.nan; 1.; 2.; Float.nan |] in
let v, _ = sort x in
v
- : (float, float32_elt) t = [1, 2, nan, nan]
Sourceval argsort :
?descending:bool ->
?axis:int ->
('a, 'b) t ->
(int32, int32_elt) t argsort ?descending ?axis t returns indices that would sort tensor.
Equivalent to snd (sort ?descending ?axis t). Returns indices such that taking elements at these indices yields sorted array.
For 1-D: resulti is the index of the i-th smallest element. For N-D: sorts along specified axis independently.
# let x = create int32 [| 5 |] [| 3l; 1l; 4l; 1l; 5l |] in
argsort x
- : (int32, int32_elt) t = [1, 3, 0, 2, 4]
# let x = create int32 [| 2; 3 |] [| 3l; 1l; 4l; 2l; 5l; 0l |] in
argsort ~axis:1 x
- : (int32, int32_elt) t = [[1, 0, 2],
[2, 0, 1]]
Linear Algebra
Matrix operations and linear algebra functions.
Most linear algebra functions require floating-point or complex tensors. Functions will raise Invalid_argument if given integer tensors.
Sourceval dot : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t dot a b computes generalized dot product.
Important: dot has different broadcasting behavior than matmul:
matmul broadcasts batch dimensionsdot does NOT broadcast; it concatenates non-contracted dimensions
For N-D × M-D arrays: dot(a, b)[i,j,k,m] = sum(a[i,j,:] * b[k,:,m]) This can result in much larger output arrays than matmul.
Contracts last axis of a with:
- 1-D
b: the only axis (axis 0) - N-D
b: second-to-last axis (axis -2)
Dimension rules:
- 1-D × 1-D: inner product, returns scalar
- 2-D × 2-D: matrix multiplication
- N-D × M-D: batched contraction over all but contracted axes
Supports broadcasting on batch dimensions. Result shape is concatenation of:
- Broadcasted batch dims
- Remaining dims from
a (except last) - Remaining dims from
b (except contracted axis)
# let a = create float32 [| 2 |] [| 1.; 2. |] in
let b = create float32 [| 2 |] [| 3.; 4. |] in
dot a b |> item []
- : float = 11.
# let a = create float32 [| 2; 2 |] [| 1.; 2.; 3.; 4. |] in
let b = create float32 [| 2; 2 |] [| 5.; 6.; 7.; 8. |] in
dot a b
- : (float, float32_elt) t = [[19, 22],
[43, 50]]
# dot (ones float32 [| 3; 4; 5 |]) (ones float32 [| 5; 6 |]) |> shape
- : int array = [|3; 4; 6|]
# dot (ones float32 [| 2; 3; 4; 5 |]) (ones float32 [| 3; 5; 6 |]) |> shape
- : int array = [|2; 3; 4; 6|]
Sourceval matmul : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t matmul a b computes matrix multiplication with broadcasting.
Follows NumPy's @ operator semantics:
- 1-D × 1-D: inner product (returns scalar tensor)
- 1-D × N-D: treated as
1 × k @ ... × k × n → ... × n - N-D × 1-D: treated as
... × m × k @ k × 1 → ... × m - N-D × M-D: batched matrix multiply on last 2 dimensions
Broadcasting rules:
- All dimensions except last 2 are broadcast together
- For 1-D inputs, dimension is temporarily added then removed
- Inner dimensions must match: a.shape
-1 == b.shape-2
Result shape:
- Batch dims: broadcast(a.shape
:-2, b.shape:-2) - Matrix dims:
..., a.shape[-2], b.shape[-1] - 1-D adjustments applied after
# let a = create float32 [| 3 |] [| 1.; 2.; 3. |] in
let b = create float32 [| 3 |] [| 4.; 5.; 6. |] in
matmul a b |> item []
- : float = 32.
# let a = create float32 [| 2; 2 |] [| 1.; 2.; 3.; 4. |] in
let b = create float32 [| 2 |] [| 5.; 6. |] in
matmul a b
- : (float, float32_elt) t = [17, 39]
# let a = create float32 [| 2 |] [| 1.; 2. |] in
let b = create float32 [| 2; 3 |] [| 3.; 4.; 5.; 6.; 7.; 8. |] in
matmul a b
- : (float, float32_elt) t = [15, 18, 21]
# matmul (ones float32 [| 10; 3; 4 |]) (ones float32 [| 10; 4; 5 |]) |> shape
- : int array = [|10; 3; 5|]
# matmul (ones float32 [| 1; 3; 4 |]) (ones float32 [| 5; 4; 2 |]) |> shape
- : int array = [|5; 3; 2|]
Sourceval diagonal :
?offset:int ->
?axis1:int ->
?axis2:int ->
('a, 'b) t ->
('a, 'b) t diagonal ?offset ?axis1 ?axis2 a extracts diagonal from 2-D planes.
offset: diagonal offset (0=main, positive=above, negative=below)axis1, axis2: axes of 2-D planes (default: last two axes)
For 2-D array, returns 1-D array of diagonal elements. For N-D array, returns array with diagonals from each 2-D subarray.
Sourceval matrix_transpose : ('a, 'b) t -> ('a, 'b) t matrix_transpose a transposes matrix dimensions.
Swaps last two axes: ..., M, N -> ..., N, M. For 1-D arrays, returns unchanged.
This is specifically for matrix operations, unlike general transpose which can permute any axes.
Sourceval vdot : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t vdot a b returns dot product of two vectors.
For complex vectors, conjugates first vector before multiplication. Always returns scalar tensor regardless of input shapes. Flattens inputs before computation.
Sourceval vecdot : ?axis:int -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t vecdot ?axis x1 x2 computes vector dot product along an axis.
axis: axis along which to compute dot product (default: -1)
Unlike vdot which always flattens, vecdot computes dot products along specified axis with broadcasting support.
Sourceval inner : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t inner a b computes inner product over last axes.
For 1-D arrays, this is ordinary inner product. For higher dimensions, sums products over last axes of a and b.
Sourceval outer : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t outer a b computes outer product of two vectors.
Given vectors ai and bj, produces matrix Mi,j = ai * bj. Input tensors are flattened if not already 1-D.
# outer (create float32 [|2|] [|1.; 2.|]) (create float32 [|3|] [|3.; 4.; 5.|])
- : (float, float32_elt) t = [[3, 4, 5],
[6, 8, 10]]
Sourceval tensordot :
?axes:(int list * int list) ->
('a, 'b) t ->
('a, 'b) t ->
('a, 'b) t tensordot ?axes a b computes tensor contraction along specified axes.
axes: pair of axis lists to contract (default: last of a, first of b)
Generalizes matrix multiplication to arbitrary dimensions.
Sourceval einsum : string -> ('a, 'b) t array -> ('a, 'b) t einsum subscripts operands evaluates Einstein summation convention.
Subscripts string specifies contraction, e.g., "ij,jk->ik" for matmul. Repeated indices are summed, free indices form output dimensions.
# let a = create float32 [| 2; 3 |] [| 1.; 2.; 3.; 4.; 5.; 6. |] in
let b = create float32 [| 3; 2 |] [| 1.; 2.; 3.; 4.; 5.; 6. |] in
shape (einsum "ij,jk->ik" [|a; b|]) (* matrix multiplication *)
- : int array = [|2; 2|]
# let a = create float32 [| 3; 3 |] [| 1.; 2.; 3.; 4.; 5.; 6.; 7.; 8.; 9. |] in
shape (einsum "ii->i" [|a|]) (* diagonal *)
- : int array = [|3|]
# let a = create float32 [| 3; 3 |] [| 1.; 2.; 3.; 4.; 5.; 6.; 7.; 8.; 9. |] in
shape (einsum "ij->ji" [|a|]) (* transpose *)
- : int array = [|3; 3|]
Sourceval kron : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t kron a b computes Kronecker product.
Result has shape a.shape[i] * b.shape[i] for i in range(ndim). Each element ai,j is replaced by ai,j * b.
Sourceval multi_dot : ('a, 'b) t array -> ('a, 'b) t multi_dot arrays computes chained matrix multiplication optimally.
Automatically selects the association order that minimizes computational cost. Much more efficient than repeated matmul for chains of 3+ matrices.
Sourceval matrix_power : ('a, 'b) t -> int -> ('a, 'b) t matrix_power a n raises square matrix to integer power.
- n > 0: a @ a @ ... @ a (n times)
- n = 0: identity matrix
- n < 0: inv(a) @ inv(a) @ ... @ inv(a) (|n| times)
Sourceval cross : ?axis:int -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t cross ?axis a b returns cross product of 3-element vectors.
axis: axis containing vectors (default: last axis)
Matrix Decompositions
Sourceval cholesky : ?upper:bool -> ('a, 'b) t -> ('a, 'b) t cholesky ?upper a computes Cholesky decomposition.
upper: return upper triangular if true (default: false)
Returns L (or U) such that a = L @ L.T (or U.T @ U).
Sourceval qr :
?mode:[ `Complete | `Reduced ] ->
('a, 'b) t ->
('a, 'b) t * ('a, 'b) t qr ?mode a computes QR decomposition.
mode: `Reduced for economy mode (default), `Complete for full
Returns (Q, R) where a = Q @ R, Q is orthogonal, R is upper triangular.
svd ?full_matrices a computes singular value decomposition.
full_matrices: compute full U, V matrices (default: false)
Returns (U, S, Vh) where a = U @ diag(S) @ Vh. S is 1-D array of singular values in descending order.
svdvals a returns singular values only.
More efficient than svd when only singular values are needed.
Eigenvalues and Eigenvectors
eig a computes eigenvalues and right eigenvectors.
Returns (eigenvalues, eigenvectors) for general square matrix. For real float32/float64 inputs, outputs are complex32/complex64 since real matrices can have complex eigenvalues.
eigh ?uplo a computes eigenvalues for symmetric/Hermitian matrix.
uplo: use upper (`U) or lower (`L) triangle (default: `L`)
Returns (eigenvalues, eigenvectors) in ascending order. For real symmetric matrices, eigenvalues are guaranteed real. More efficient than eig for symmetric matrices.
eigvals a computes eigenvalues only.
For real inputs, returns complex tensor since eigenvalues may be complex. More efficient than eig when eigenvectors not needed.
eigvalsh ?uplo a computes eigenvalues for symmetric/Hermitian matrix.
For real symmetric inputs, returns real eigenvalues. More efficient than eigvals for symmetric matrices.
Norms and Condition Numbers
Sourceval norm :
?ord:
[ `Fro
| `Nuc
| `One
| `Two
| `Inf
| `NegOne
| `NegTwo
| `NegInf
| `P of float ] ->
?axes:int list ->
?keepdims:bool ->
('a, 'b) t ->
('a, 'b) t norm ?ord ?axes ?keepdims x computes matrix or vector norm.
ord: norm type (default: Frobenius for matrices, 2-norm for vectors)`Fro: Frobenius norm`Nuc: nuclear norm (sum of singular values)`One: max column sum (for matrices)`Two: spectral norm (largest singular value)`Inf: max row sum (for matrices)`NegOne: min column sum`NegTwo: smallest singular value`NegInf: min row sum`P p: p-norm for vectorsaxes: axes to compute norm over. For matrix norms, must be 2-element listkeepdims: keep reduced dimensions as size 1
Sourceval cond :
?p:[ `One | `Two | `Inf | `NegOne | `NegTwo | `NegInf | `Fro ] ->
('a, 'b) t ->
('a, 'b) t cond ?p x computes condition number.
p: norm to use (default: 2-norm)`One: 1-norm (max column sum)`Two: 2-norm (max singular value)`Inf: infinity norm (max row sum)`NegOne: -1 norm (min column sum)`NegTwo: -2 norm (min singular value)`NegInf: -infinity norm (min row sum)`Fro: Frobenius norm
Returns ratio of largest to smallest norm.
Sourceval det : ('a, 'b) t -> ('a, 'b) t det a computes determinant of square matrix.
slogdet a computes sign and log of determinant.
Returns (sign, logdet) where det(a) = sign * exp(logdet). More stable than det for matrices with very small/large determinants.
Sourceval matrix_rank :
?tol:float ->
?rtol:float ->
?hermitian:bool ->
('a, 'b) t ->
int matrix_rank ?tol ?rtol ?hermitian a returns rank of matrix.
tol: absolute tolerance for small singular valuesrtol: relative tolerance (default: max(M,N) * eps * largest_singular_value)hermitian: if true, use more efficient algorithm for Hermitian matrices
Counts singular values greater than tolerance.
Sourceval trace : ?offset:int -> ('a, 'b) t -> ('a, 'b) t trace ?offset a returns sum along diagonal.
offset: diagonal offset (default: 0, positive for upper diagonals)
Solving Linear Systems
Sourceval solve : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t solve a b solves linear system a @ x = b for x.
Supports batched operations when a, b have compatible batch dimensions.
Sourceval lstsq :
?rcond:float ->
('a, 'b) t ->
('a, 'b) t ->
('a, 'b) t * ('a, 'b) t * int * (float, float64_elt) t lstsq ?rcond a b computes least-squares solution to a @ x = b.
rcond: cutoff for small singular values (default: machine precision)
Returns (solution, residuals, rank, singular_values). Handles over/under-determined systems.
Sourceval inv : ('a, 'b) t -> ('a, 'b) t inv a computes inverse of square matrix.
Sourceval pinv : ?rtol:float -> ?hermitian:bool -> ('a, 'b) t -> ('a, 'b) t pinv ?rtol ?hermitian a computes Moore-Penrose pseudoinverse.
rtol: relative tolerance for small singular valueshermitian: if true, use more efficient algorithm for Hermitian matrices
Handles non-square and singular matrices.
Sourceval tensorsolve : ?axes:int list -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t tensorsolve ?axes a b solves tensor equation a x = b for x.
axes: axes in a to reorder to end (default: product of b.ndim rightmost axes)
Solves for x such that tensordot(a, x, axes) = b.
Sourceval tensorinv : ?ind:int -> ('a, 'b) t -> ('a, 'b) t tensorinv ?ind a computes 'inverse' of N-D array.
ind: number of first indices involved in inverse sum (default: 2)
Result is such that tensordot(a, a_inv, ind) = I.
Fast Fourier Transform (FFT) and related signal processing functions.
Sourcetype fft_norm = [ | `Backward| `Forward| `Ortho
] FFT normalization mode:
`Backward: normalize by 1/n on inverse transform (default)`Forward: normalize by 1/n on forward transform`Ortho: normalize by 1/sqrt(n) on both transforms
fft ?axis ?n ?norm x computes discrete Fourier transform over specified axis.
axis: axis to transform (default: last axis)n: Length of the transformed axis of the outputnorm: normalization mode (default: `Backward)
Computing 1D FFT of a signal:
# let x = create complex64 [|4|]
[|Complex.{re=0.; im=0.}; {re=1.; im=0.};
{re=2.; im=0.}; {re=3.; im=0.}|] in
let result = fft ~axis:0 x in
shape result
- : int array = [|4|]
ifft ?axis ?s ?norm x computes inverse discrete Fourier transform.
axis: axis to transform (default: last axis)n: Length of the transformed axis of the outputnorm: normalization mode (default: `Backward)
fft2 ?axes ?s ?norm x computes 2-dimensional FFT.
Transforms last two axes by default. Truncates or pads to shape s if given.
Computing 2D FFT of a 2x2 matrix:
# let x = create complex64 [|2; 2|]
[|Complex.{re=1.; im=0.}; {re=2.; im=0.};
{re=3.; im=0.}; {re=4.; im=0.}|] in
shape (fft2 x)
- : int array = [|2; 2|]
ifft2 ?axes ?s ?norm x computes 2-dimensional inverse FFT.
fftn ?axes ?s ?norm x computes N-dimensional FFT.
Transforms all axes by default.
ifftn ?axes ?s ?norm x computes N-dimensional inverse FFT.
rfft ?axis ?n ?norm x computes FFT of real input.
Returns only non-redundant positive frequencies. Output size along last transformed axis is n/2+1 where n is input size.
axes: axes to transform (default: last axis)s: shape to truncate/pad to before transformnorm: normalization mode (default: `Backward)
Computing real FFT:
# let x = create float64 [|4|] [|0.; 1.; 2.; 3.|] in
let result = rfft ~axis:0 x in
shape result
- : int array = [|3|]
irfft ?axis ?n ?norm x computes inverse FFT returning real output.
Assumes Hermitian symmetry.
axis: axis to transform (default: last axis)n: output shape along transformed axesnorm: normalization mode (default: `Backward)
rfft2 ?axes ?s ?norm x computes 2D FFT of real input.
irfft2 ?axes ?s ?norm x computes 2D inverse FFT returning real output.
rfftn ?axes ?s ?norm x computes N-dimensional FFT of real input.
irfftn ?axes ?s ?norm x computes N-dimensional inverse FFT returning real output.
hfft x ~n ~axis computes FFT of Hermitian signal.
Interprets input as positive frequencies of Hermitian signal.
ihfft x ~n ~axis computes inverse FFT for Hermitian output.
fftfreq ?d n returns DFT sample frequencies.
For window length n and sample spacing d, returns frequencies 0, 1, ..., n/2-1, -n/2, ..., -1 / (d*n) if n is even.
Getting frequencies for 4-point FFT:
# Nx.fftfreq 4
- : (float, float64_elt) t = [0, 0.25, -0.5, -0.25]
rfftfreq ?d n returns positive DFT frequencies.
Returns 0, 1, ..., n/2 / (d*n).
Sourceval fftshift : ?axes:int list -> ('a, 'b) t -> ('a, 'b) t fftshift x ?axes shifts zero-frequency component to center.
Shifts all axes by default. For visualization of frequency spectra.
Centering frequency spectrum:
# let freqs = fftfreq 5 in
fftshift freqs
- : (float, float64_elt) t = [-0.4, -0.2, 0, 0.2, 0.4]
Sourceval ifftshift : ?axes:int list -> ('a, 'b) t -> ('a, 'b) t ifftshift x ?axes undoes fftshift.
Activation Functions
Neural network activation functions.
Sourceval relu : ('a, 'b) t -> ('a, 'b) t relu t applies Rectified Linear Unit: max(0, x).
# let x = create float32 [| 5 |] [| -2.; -1.; 0.; 1.; 2. |] in
relu x
- : (float, float32_elt) t = [0, 0, 0, 1, 2]
Sourceval relu6 : (float, 'a) t -> (float, 'a) t relu6 t applies ReLU6: min(max(0, x), 6).
Bounded ReLU used in mobile networks. Clips to 0, 6 range.
# let x = create float32 [| 3 |] [| -1.; 3.; 8. |] in
relu6 x
- : (float, float32_elt) t = [0, 3, 6]
Sourceval sigmoid : (float, 'a) t -> (float, 'a) t sigmoid t applies logistic sigmoid: 1 / (1 + exp(-x)).
Output in range (0, 1). Symmetric around x=0 where sigmoid(0) = 0.5.
# sigmoid (scalar float32 0.) |> item []
- : float = 0.5
# sigmoid (scalar float32 10.) |> item [] |> Float.round
- : float = 1.
# sigmoid (scalar float32 (-10.)) |> item [] |> Float.round
- : float = 0.
Sourceval hard_sigmoid :
?alpha:float ->
?beta:float ->
(float, 'a) t ->
(float, 'a) t hard_sigmoid ?alpha ?beta t applies piecewise linear sigmoid approximation.
Default alpha = 1/6, beta = 0.5.
Sourceval softplus : (float, 'a) t -> (float, 'a) t softplus t applies smooth ReLU: log(1 + exp(x)).
Smooth approximation to ReLU. Always positive, differentiable everywhere.
# softplus (scalar float32 0.) |> item [] |> Float.round
- : float = 1.
# softplus (scalar float32 100.) |> item [] |> Float.round
- : float = infinity
Sourceval silu : (float, 'a) t -> (float, 'a) t silu t applies Sigmoid Linear Unit: x * sigmoid(x).
Also called Swish. Smooth, non-monotonic activation.
# silu (scalar float32 0.) |> item []
- : float = 0.
# silu (scalar float32 1.) |> item [] |> Float.round
- : float = 1.
# silu (scalar float32 (-1.)) |> item [] |> Float.round
- : float = -0.
Sourceval hard_silu : (float, 'a) t -> (float, 'a) t hard_silu t applies x * hard_sigmoid(x).
Piecewise linear approximation of SiLU. More efficient than SiLU.
# let x = create float32 [| 3 |] [| -3.; 0.; 3. |] in
hard_silu x
- : (float, float32_elt) t = [-0, 0, 3]
Sourceval log_sigmoid : (float, 'a) t -> (float, 'a) t log_sigmoid t computes log(sigmoid(x)).
Numerically stable version of log(1/(1+exp(-x))). Always negative.
# log_sigmoid (scalar float32 0.) |> item [] |> Float.round
- : float = -1.
# log_sigmoid (scalar float32 100.) |> item [] |> Float.abs |> (fun x -> x < 0.001)
- : bool = true
Sourceval leaky_relu : ?negative_slope:float -> (float, 'a) t -> (float, 'a) t leaky_relu ?negative_slope t applies Leaky ReLU.
Default negative_slope = 0.01. Returns x if x > 0, else negative_slope * x.
Sourceval hard_tanh : (float, 'a) t -> (float, 'a) t hard_tanh t clips values to -1, 1.
Linear in -1, 1, saturates outside. Cheaper than tanh.
# let x = create float32 [| 5 |] [| -2.; -0.5; 0.; 0.5; 2. |] in
hard_tanh x
- : (float, float32_elt) t = [-1, -0.5, 0, 0.5, 1]
Sourceval elu : ?alpha:float -> (float, 'a) t -> (float, 'a) t elu ?alpha t applies Exponential Linear Unit.
Default alpha = 1.0. Returns x if x > 0, else alpha * (exp(x) - 1). Smooth for x < 0, helps with vanishing gradients.
# elu (scalar float32 1.) |> item []
- : float = 1.
# elu (scalar float32 0.) |> item []
- : float = 0.
# elu (scalar float32 (-1.)) |> item [] |> Float.round
- : float = -1.
Sourceval selu : (float, 'a) t -> (float, 'a) t selu t applies Scaled ELU with fixed alpha=1.67326, lambda=1.0507.
Self-normalizing activation. Preserves mean 0 and variance 1 in deep networks under certain conditions.
# selu (scalar float32 0.) |> item []
- : float = 0.
# selu (scalar float32 1.) |> item [] |> Float.round
- : float = 1.
Sourceval softmax : ?axes:int list -> (float, 'a) t -> (float, 'a) t softmax ?axes t applies softmax normalization.
Default axis -1. Computes exp(x - max) / sum(exp(x - max)) for numerical stability. Output sums to 1 along specified axes.
# let x = create float32 [| 3 |] [| 1.; 2.; 3. |] in
softmax x |> to_array |> Array.map Float.round
- : float array = [|0.; 0.; 1.|]
# let x = create float32 [| 3 |] [| 1.; 2.; 3. |] in
sum (softmax x) |> item []
- : float = 1.
Sourceval erf : (float, 'a) t -> (float, 'a) t erf t computes the error function.
The error function erf(x) = (2/√π) ∫₀ˣ e^(-t²) dt. Uses Abramowitz and Stegun approximation for numerical stability.
# erf (scalar float32 0.) |> item []
- : float = 0.
# let result = erf (scalar float32 1.) |> item [] in
Float.round (result *. 10000.) /. 10000. (* Round to 4 decimals *)
- : float = 0.8427
Sourceval gelu : (float, 'a) t -> (float, 'a) t gelu t applies exact Gaussian Error Linear Unit.
GELU(x) = 0.5 * x * (1 + erf(x / √2))
This is the exact GELU using error function, more numerically stable than the approximation for gradient computation.
# gelu (scalar float32 0.) |> item []
- : float = 0.
# gelu (scalar float32 1.) |> item [] |> Float.round
- : float = 1.
Sourceval gelu_approx : (float, 'a) t -> (float, 'a) t gelu_approx t applies Gaussian Error Linear Unit approximation.
Smooth activation: x * Φ(x) where Φ is Gaussian CDF. This uses tanh approximation for efficiency.
# gelu_approx (scalar float32 0.) |> item []
- : float = 0.
# gelu_approx (scalar float32 1.) |> item [] |> Float.round
- : float = 1.
Sourceval softsign : (float, 'a) t -> (float, 'a) t softsign t computes x / (|x| + 1).
Similar to tanh but computationally cheaper. Range (-1, 1).
# let x = create float32 [| 3 |] [| -10.; 0.; 10. |] in
softsign x
- : (float, float32_elt) t = [-0.909091, 0, 0.909091]
Sourceval mish : (float, 'a) t -> (float, 'a) t mish t applies Mish activation: x * tanh(softplus(x)).
Self-regularizing non-monotonic activation. Smoother than ReLU.
# mish (scalar float32 0.) |> item [] |> Float.abs |> (fun x -> x < 0.001)
- : bool = true
# mish (scalar float32 (-10.)) |> item [] |> Float.round
- : float = -0.
Convolution and Pooling
Neural network convolution and pooling operations.
Sourceval im2col :
kernel_size:int array ->
stride:int array ->
dilation:int array ->
padding:(int * int) array ->
('a, 'b) t ->
('a, 'b) t im2col ~kernel_size ~stride ~dilation ~padding t extracts sliding local blocks from tensor.
Extracts patches of size kernel_size from the input tensor at the specified stride and dilation.
kernel_size: size of sliding blocks to extractstride: step between consecutive blocksdilation: spacing between kernel elementspadding: (before, after) padding for each spatial dimension
For a 4D input batch; channels; height; width, produces output shape batch; channels * kh * kw; num_patches_h; num_patches_w where kh, kw are kernel dimensions and num_patches depends on stride and padding.
# let x = arange_f float32 0. 16. 1. |> reshape [| 1; 1; 4; 4 |] in
im2col ~kernel_size:[|2; 2|] ~stride:[|1; 1|]
~dilation:[|1; 1|] ~padding:[|(0, 0); (0, 0)|] x |> shape
- : int array = [|1; 4; 9|]
Sourceval col2im :
output_size:int array ->
kernel_size:int array ->
stride:int array ->
dilation:int array ->
padding:(int * int) array ->
('a, 'b) t ->
('a, 'b) t col2im ~output_size ~kernel_size ~stride ~dilation ~padding t combines sliding local blocks into tensor.
This is the inverse of im2col. Accumulates values from the unfolded representation back into spatial dimensions. Overlapping regions are summed.
output_size: target spatial dimensions height; widthkernel_size: size of sliding blocksstride: step between consecutive blocksdilation: spacing between kernel elementspadding: (before, after) padding for each spatial dimension
For input shape batch; channels * kh * kw; num_patches_h; num_patches_w, produces output batch; channels; height; width.
# let unfolded = create float32 [| 1; 4; 9 |] (Array.init 36 Float.of_int) in
col2im ~output_size:[|4; 4|] ~kernel_size:[|2; 2|]
~stride:[|1; 1|] ~dilation:[|1; 1|]
~padding:[|(0, 0); (0, 0)|] unfolded |> shape
- : int array = [|1; 1; 4; 4|]
Sourceval correlate1d :
?groups:int ->
?stride:int ->
?padding_mode:[ `Full | `Same | `Valid ] ->
?dilation:int ->
?fillvalue:float ->
?bias:(float, 'a) t ->
(float, 'a) t ->
(float, 'a) t ->
(float, 'a) t correlate1d ?groups ?stride ?padding_mode ?dilation ?fillvalue ?bias x w computes 1D cross-correlation (no kernel flip).
x: input batch_size; channels_in; widthw: weights channels_out; channels_in/groups; kernel_widthbias: optional per-channel bias channels_outgroups: split input/output channels into groups (default 1)stride: step between windows (default 1)padding_mode: `Valid (no pad), `Same (preserve size), `Full (all overlaps)dilation: spacing between kernel elements (default 1)fillvalue: padding value (default 0.0)
Output width depends on padding:
- `Valid: (width - dilation*(kernel-1) - 1)/stride + 1
- `Same: width/stride (rounded up)
- `Full: (width + dilation*(kernel-1) - 1)/stride + 1
# let x = create float32 [| 1; 1; 5 |] [| 1.; 2.; 3.; 4.; 5. |] in
let w = create float32 [| 1; 1; 3 |] [| 1.; 0.; -1. |] in
correlate1d x w |> shape
- : int array = [|1; 1; 3|]
Sourceval correlate2d :
?groups:int ->
?stride:(int * int) ->
?padding_mode:[ `Full | `Same | `Valid ] ->
?dilation:(int * int) ->
?fillvalue:float ->
?bias:(float, 'a) t ->
(float, 'a) t ->
(float, 'a) t ->
(float, 'a) t correlate2d ?groups ?stride ?padding_mode ?dilation ?fillvalue ?bias x w computes 2D cross-correlation (no kernel flip).
x: input batch; channels_in; height; widthw: weights channels_out; channels_in/groups; kernel_h; kernel_wbias: optional per-channel bias channels_outstride: (stride_h, stride_w) step between windows (default (1,1))dilation: (dilation_h, dilation_w) kernel spacing (default (1,1))padding_mode: `Valid (no pad), `Same (preserve size), `Full (all overlaps)
Uses Winograd F(4,3) for 3×3 kernels with stride 1 when beneficial. For `Same` with even kernels, pads more on bottom/right (SciPy convention).
# let image = ones float32 [| 1; 1; 5; 5 |] in
let sobel_x = create float32 [| 1; 1; 3; 3 |] [| 1.; 0.; -1.; 2.; 0.; -2.; 1.; 0.; -1. |] in
correlate2d image sobel_x |> shape
- : int array = [|1; 1; 3; 3|]
Sourceval convolve1d :
?groups:int ->
?stride:int ->
?padding_mode:[< `Full | `Same | `Valid Valid ] ->
?dilation:int ->
?fillvalue:'a ->
?bias:('a, 'b) t ->
('a, 'b) t ->
('a, 'b) t ->
('a, 'b) t convolve1d ?groups ?stride ?padding_mode ?dilation ?fillvalue ?bias x w computes 1D convolution (flips kernel before correlation).
Same parameters as correlate1d but flips kernel. For `Same` with even kernels, pads more on left (NumPy convention).
# let x = create float32 [| 1; 1; 3 |] [| 1.; 2.; 3. |] in
let w = create float32 [| 1; 1; 2 |] [| 4.; 5. |] in
convolve1d x w
- : (float, float32_elt) t = [[[13, 22]]]
Sourceval convolve2d :
?groups:int ->
?stride:(int * int) ->
?padding_mode:[< `Full | `Same | `Valid Valid ] ->
?dilation:(int * int) ->
?fillvalue:'a ->
?bias:('a, 'b) t ->
('a, 'b) t ->
('a, 'b) t ->
('a, 'b) t convolve2d ?groups ?stride ?padding_mode ?dilation ?fillvalue ?bias x w computes 2D convolution (flips kernel before correlation).
Same parameters as correlate2d but flips kernel horizontally and vertically. For `Same` with even kernels, pads more on top/left.
# let image = ones float32 [| 1; 1; 5; 5 |] in
let gaussian = create float32 [| 1; 1; 3; 3 |] [| 1.; 2.; 1.; 2.; 4.; 2.; 1.; 2.; 1. |] in
convolve2d image (mul_s gaussian (1. /. 16.)) |> shape
- : int array = [|1; 1; 3; 3|]
Sourceval avg_pool1d :
kernel_size:int ->
?stride:int ->
?dilation:int ->
?padding_spec:[< `Full | `Same | `Valid Valid ] ->
?ceil_mode:bool ->
?count_include_pad:bool ->
(float, 'a) t ->
(float, 'a) t avg_pool1d ~kernel_size ?stride ?dilation ?padding_spec ?ceil_mode ?count_include_pad x applies 1D average pooling.
kernel_size: pooling window sizestride: step between windows (default: kernel_size)dilation: spacing between kernel elements (default 1)padding_spec: same as convolution padding modesceil_mode: use ceiling for output size calculation (default false)count_include_pad: include padding in average (default true)
Input shape: batch; channels; width Output width: (width + 2*pad - dilation*(kernel-1) - 1)/stride + 1
# let x = create float32 [| 1; 1; 4 |] [| 1.; 2.; 3.; 4. |] in
avg_pool1d ~kernel_size:2 x
- : (float, float32_elt) t = [[[1.5, 3.5]]]
Sourceval avg_pool2d :
kernel_size:(int * int) ->
?stride:(int * int) ->
?dilation:(int * int) ->
?padding_spec:[< `Full | `Same | `Valid Valid ] ->
?ceil_mode:bool ->
?count_include_pad:bool ->
(float, 'a) t ->
(float, 'a) t avg_pool2d ~kernel_size ?stride ?dilation ?padding_spec ?ceil_mode ?count_include_pad x applies 2D average pooling.
kernel_size: (height, width) of pooling windowstride: (stride_h, stride_w) (default: kernel_size)dilation: (dilation_h, dilation_w) (default (1,1))count_include_pad: whether padding contributes to denominator
Input shape: batch; channels; height; width
# let x = create float32 [| 1; 1; 2; 2 |] [| 1.; 2.; 3.; 4. |] in
avg_pool2d ~kernel_size:(2, 2) x
- : (float, float32_elt) t = [[[[2.5]]]]
Sourceval max_pool1d :
kernel_size:int ->
?stride:int ->
?dilation:int ->
?padding_spec:[< `Full | `Same | `Valid Valid ] ->
?ceil_mode:bool ->
?return_indices:bool ->
('a, 'b) t ->
('a, 'b) t * (int32, int32_elt) t option max_pool1d ~kernel_size ?stride ?dilation ?padding_spec ?ceil_mode ?return_indices x applies 1D max pooling.
return_indices: if true, also returns indices of max values for unpooling- Other parameters same as
avg_pool1d
Returns (pooled_values, Some indices) if return_indices=true, otherwise (pooled_values, None). Indices are flattened positions in input.
# let x = create float32 [| 1; 1; 4 |] [| 1.; 3.; 2.; 4. |] in
let vals, idx = max_pool1d ~kernel_size:2 ~return_indices:true x in
vals, idx
- : (float, float32_elt) t * (int32, int32_elt) t option =
([[[3, 4]]], Some [[[1, 1]]])
Sourceval max_pool2d :
kernel_size:(int * int) ->
?stride:(int * int) ->
?dilation:(int * int) ->
?padding_spec:[< `Full | `Same | `Valid Valid ] ->
?ceil_mode:bool ->
?return_indices:bool ->
('a, 'b) t ->
('a, 'b) t * (int32, int32_elt) t option max_pool2d ~kernel_size ?stride ?dilation ?padding_spec ?ceil_mode ?return_indices x applies 2D max pooling.
Parameters same as max_pool1d but for 2D. Indices encode flattened position within each pooling window.
# let x = create float32 [| 1; 1; 4; 4 |]
[| 1.; 2.; 5.; 6.; 3.; 4.; 7.; 8.; 9.; 10.; 13.; 14.; 11.; 12.; 15.; 16. |] in
let vals, _ = max_pool2d ~kernel_size:(2, 2) ~stride:(2, 2) x in
vals
- : (float, float32_elt) t = [[[[4, 8],
[12, 16]]]]
Sourceval min_pool1d :
kernel_size:int ->
?stride:int ->
?dilation:int ->
?padding_spec:[< `Full | `Same | `Valid Valid ] ->
?ceil_mode:bool ->
?return_indices:bool ->
('a, 'b) t ->
('a, 'b) t * (int32, int32_elt) t option min_pool1d ~kernel_size ?stride ?dilation ?padding_spec ?ceil_mode ?return_indices x applies 1D min pooling.
return_indices: if true, also returns indices of min values (currently returns None)- Other parameters same as
avg_pool1d
Returns (pooled_values, None). Index tracking not yet implemented.
# let x = create float32 [| 1; 1; 4 |] [| 4.; 2.; 3.; 1. |] in
let vals, _ = min_pool1d ~kernel_size:2 x in
vals
- : (float, float32_elt) t = [[[2, 1]]]
Sourceval min_pool2d :
kernel_size:(int * int) ->
?stride:(int * int) ->
?dilation:(int * int) ->
?padding_spec:[< `Full | `Same | `Valid Valid ] ->
?ceil_mode:bool ->
?return_indices:bool ->
('a, 'b) t ->
('a, 'b) t * (int32, int32_elt) t option min_pool2d ~kernel_size ?stride ?dilation ?padding_spec ?ceil_mode ?return_indices x applies 2D min pooling.
Parameters same as min_pool1d but for 2D. Commonly used for morphological erosion operations in image processing.
# let x = create float32 [| 1; 1; 4; 4 |]
[| 1.; 2.; 5.; 6.; 3.; 4.; 7.; 8.; 9.; 10.; 13.; 14.; 11.; 12.; 15.; 16. |] in
let vals, _ = min_pool2d ~kernel_size:(2, 2) ~stride:(2, 2) x in
vals
- : (float, float32_elt) t = [[[[1, 5],
[9, 13]]]]
Sourceval max_unpool1d :
(int, uint8_elt) t ->
('a, 'b) t ->
kernel_size:int ->
?stride:int ->
?dilation:int ->
?padding_spec:[< `Full | `Same | `Valid Valid ] ->
?output_size_opt:int array ->
unit ->
(int, uint8_elt) t max_unpool1d indices values ~kernel_size ?stride ?dilation ?padding_spec ?output_size_opt () reverses max pooling.
indices: indices from max_pool1d with return_indices=truevalues: pooled values to place at indexed positionskernel_size, stride, dilation, padding_spec: must match original pooloutput_size_opt: exact output shape (inferred if not provided)
Places values at positions indicated by indices, fills rest with zeros. Output size computed from input unless explicitly specified.
# let x = create float32 [| 1; 1; 4 |] [| 1.; 3.; 2.; 4. |] in
let pooled, _ = max_pool1d ~kernel_size:2 x in
pooled
- : (float, float32_elt) t = [[[3, 4]]]
Sourceval max_unpool2d :
(int, uint8_elt) t ->
('a, 'b) t ->
kernel_size:(int * int) ->
?stride:(int * int) ->
?dilation:(int * int) ->
?padding_spec:[< `Full | `Same | `Valid Valid ] ->
?output_size_opt:int array ->
unit ->
(int, uint8_elt) t max_unpool2d indices values ~kernel_size ?stride ?dilation ?padding_spec ?output_size_opt () reverses 2D max pooling.
Same as max_unpool1d but for 2D. Indices encode position within each pooling window. Useful for architectures like segmentation networks that need to "remember" where maxima came from.
# let x = create float32 [| 1; 1; 4; 4 |]
[| 1.; 2.; 3.; 4.; 5.; 6.; 7.; 8.;
9.; 10.; 11.; 12.; 13.; 14.; 15.; 16. |] in
let pooled, _ = max_pool2d ~kernel_size:(2,2) x in
pooled
- : (float, float32_elt) t = [[[[6, 8],
[14, 16]]]]
one_hot ~num_classes indices creates one-hot encoding.
Adds new last dimension of size num_classes. Values must be in [0, num_classes). Out-of-range indices produce zero vectors.
# let indices = create int32 [| 3 |] [| 0l; 1l; 3l |] in
one_hot ~num_classes:4 indices
- : (int, uint8_elt) t = [[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 0, 1]]
# let indices = create int32 [| 2; 2 |] [| 0l; 2l; 1l; 0l |] in
one_hot ~num_classes:3 indices |> shape
- : int array = [|2; 2; 3|]
Iteration and Mapping
Functions to iterate over and transform arrays.
Sourceval map_item : ('a -> 'a) -> ('a, 'b) t -> ('a, 'b) t map_item f t applies f to each element.
Operates on contiguous data directly. Type-preserving only.
Sourceval iter_item : ('a -> unit) -> ('a, 'b) t -> unit iter_item f t applies f to each element for side effects.
Sourceval fold_item : ('a -> 'b -> 'a) -> 'a -> ('b, 'c) t -> 'a fold_item f init t folds f over elements.
Sourceval map : (('a, 'b) t -> ('a, 'b) t) -> ('a, 'b) t -> ('a, 'b) t map f t applies tensor function f to each element as scalar tensor.
Sourceval iter : (('a, 'b) t -> unit) -> ('a, 'b) t -> unit iter f t applies tensor function f to each element.
Sourceval fold : ('a -> ('b, 'c) t -> 'a) -> 'a -> ('b, 'c) t -> 'a fold f init t folds tensor function over elements.
Printing and Display
Functions to display arrays and convert to strings.
pp_data fmt t pretty-prints tensor data.
format_to_string pp x converts using pretty-printer.
print_with_formatter pp x prints using formatter.
Sourceval data_to_string : ('a, 'b) t -> string data_to_string t converts tensor data to string.
Sourceval print_data : ('a, 'b) t -> unit print_data t prints tensor data to stdout.
pp_dtype fmt dt pretty-prints dtype.
dtype_to_string dt converts dtype to string.
Sourceval shape_to_string : int array -> string shape_to_string shape formats shape as "2x3x4".
pp_shape fmt shape pretty-prints shape.
pp fmt t pretty-prints tensor info and data.
Sourceval print : ('a, 'b) t -> unit print t prints tensor info and data to stdout.
Sourceval to_string : ('a, 'b) t -> string to_string t converts tensor info and data to string.