package rune

Module Rune

Functional transformations for Nx tensors.

Rune provides automatic differentiation (forward and reverse mode), vectorising maps, and gradient checking. It operates by intercepting Nx tensor operations via OCaml 5 effect handlers — no special tensor type is needed.

Terminology.

  • Primal: the input value at which a derivative is evaluated.
  • Tangent: the directional derivative seed (forward mode).
  • Cotangent: the adjoint seed propagated backward (reverse mode).
  • JVP: Jacobian-vector product (forward-mode AD).
  • VJP: vector-Jacobian product (reverse-mode AD).

Reverse-mode AD

Compute gradients of scalar-valued functions via reverse-mode (backpropagation). The function f must return a scalar tensor; the gradient has the same shape as the input.

val grad : (('a, 'b) Nx.t -> ('c, 'd) Nx.t) -> ('a, 'b) Nx.t -> ('a, 'b) Nx.t

grad f x is the gradient of scalar-valued f at x.

Equivalent to snd (value_and_grad f x).

See also grads, value_and_grad.
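
For example, for f(x) the sum of squares, the gradient at x is 2x. A minimal sketch (Nx.create, Nx.float32, Nx.mul, and Nx.sum are assumed from the Nx API):

  let f x = Nx.sum (Nx.mul x x)   (* f(x) = sum of squares *)
  let x = Nx.create Nx.float32 [| 3 |] [| 1.; 2.; 3. |]
  let g = Rune.grad f x           (* 2x = [| 2.; 4.; 6. |], same shape as x *)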

val grads : (('a, 'b) Nx.t list -> ('c, 'd) Nx.t) -> ('a, 'b) Nx.t list -> ('a, 'b) Nx.t list

grads f xs is the list of gradients of scalar-valued f with respect to each tensor in xs. The i-th element of the result has the same shape as the i-th element of xs.

See also grad, value_and_grads.
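
A sketch with two inputs (same Nx assumptions as the grad example, plus Nx.add):

  let f = function
    | [ w; b ] -> Nx.sum (Nx.add (Nx.mul w w) b)
    | _ -> assert false
  let w = Nx.create Nx.float32 [| 2 |] [| 1.; 2. |]
  let b = Nx.create Nx.float32 [| 2 |] [| 3.; 4. |]
  let gs = Rune.grads f [ w; b ]   (* [ 2w; ones-like b ] *)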

val value_and_grad : (('a, 'b) Nx.t -> ('c, 'd) Nx.t) -> ('a, 'b) Nx.t -> ('c, 'd) Nx.t * ('a, 'b) Nx.t

value_and_grad f x is (f x, grad f x), computed in a single forward-backward pass.

See also value_and_grad_aux.
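
Typical use is a training step that needs both the loss and its gradient; a sketch (Nx.sub and Nx.scalar are assumed to exist and broadcast):

  let f x = Nx.sum (Nx.mul x x)
  let x = Nx.create Nx.float32 [| 3 |] [| 1.; 2.; 3. |]
  let loss, g = Rune.value_and_grad f x
  (* one gradient-descent step *)
  let x' = Nx.sub x (Nx.mul (Nx.scalar Nx.float32 0.1) g)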

val value_and_grad_aux : (('a, 'b) Nx.t -> ('c, 'd) Nx.t * 'e) -> ('a, 'b) Nx.t -> ('c, 'd) Nx.t * ('a, 'b) Nx.t * 'e

value_and_grad_aux f x is (y, g, aux) where (y, aux) = f x and g is the gradient of y with respect to x. The auxiliary output aux is carried through but not differentiated.

See also value_and_grads_aux.
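
A sketch where a diagnostic value rides along with the loss (Nx.max is an assumed reduction):

  let x = Nx.create Nx.float32 [| 3 |] [| 1.; 2.; 3. |]
  let f x =
    let loss = Nx.sum (Nx.mul x x) in
    let diag = Nx.max x in   (* carried through, not differentiated *)
    (loss, diag)
  let loss, g, diag = Rune.value_and_grad_aux f x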

val value_and_grads : (('a, 'b) Nx.t list -> ('c, 'd) Nx.t) -> ('a, 'b) Nx.t list -> ('c, 'd) Nx.t * ('a, 'b) Nx.t list

value_and_grads f xs is (f xs, grads f xs), computed in a single forward-backward pass.

See also value_and_grads_aux.

val value_and_grads_aux : (('a, 'b) Nx.t list -> ('c, 'd) Nx.t * 'e) -> ('a, 'b) Nx.t list -> ('c, 'd) Nx.t * ('a, 'b) Nx.t list * 'e

value_and_grads_aux f xs is (y, gs, aux) where (y, aux) = f xs and gs is the list of gradients of y with respect to each tensor in xs. The auxiliary output aux is carried through but not differentiated.

See also value_and_grad_aux.

val vjp : (('a, 'b) Nx.t -> ('c, 'd) Nx.t) -> ('a, 'b) Nx.t -> ('c, 'd) Nx.t -> ('c, 'd) Nx.t * ('a, 'b) Nx.t

vjp f x v is (y, g) where y = f x and g = vᵀ J_f(x) (vector-Jacobian product). Unlike grad, f need not return a scalar; the cotangent v must have the same shape as y.

See also vjps.
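
A sketch with a non-scalar output (Nx.ones_like is assumed):

  let f x = Nx.mul x x                 (* elementwise square; output is not scalar *)
  let x = Nx.create Nx.float32 [| 3 |] [| 1.; 2.; 3. |]
  let v = Nx.ones_like x               (* cotangent, same shape as f x *)
  let y, g = Rune.vjp f x v            (* g = vᵀ J_f(x) = 2x here *)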

val vjps : (('a, 'b) Nx.t list -> ('c, 'd) Nx.t) -> ('a, 'b) Nx.t list -> ('c, 'd) Nx.t -> ('c, 'd) Nx.t * ('a, 'b) Nx.t list

vjps f xs v is like vjp for functions with multiple inputs. Returns (y, gs) where each gradient in gs corresponds to one input in xs.

Forward-mode AD

Compute Jacobian-vector products by propagating tangent vectors alongside primal values. Forward mode is efficient when the number of inputs is small relative to the number of outputs.

val jvp : (('a, 'b) Nx.t -> ('c, 'd) Nx.t) -> ('a, 'b) Nx.t -> ('a, 'b) Nx.t -> ('c, 'd) Nx.t * ('c, 'd) Nx.t

jvp f x v is (y, t) where y = f x and t = J_f(x) v (Jacobian-vector product). The tangent v must have the same shape as x.

See also jvps, jvp_aux.
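
A sketch (Nx.sin and Nx.ones_like are assumed):

  let x = Nx.create Nx.float32 [| 3 |] [| 0.; 1.; 2. |]
  let v = Nx.ones_like x               (* tangent, same shape as x *)
  let y, t = Rune.jvp Nx.sin x v       (* t = cos x ⊙ v for elementwise sin *)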

val jvp_aux : (('a, 'b) Nx.t -> ('c, 'd) Nx.t * 'e) -> ('a, 'b) Nx.t -> ('a, 'b) Nx.t -> ('c, 'd) Nx.t * ('c, 'd) Nx.t * 'e

jvp_aux f x v is like jvp but for functions with auxiliary output. Returns (y, t, aux) where aux is carried through but not differentiated.

val jvps : (('a, 'b) Nx.t list -> ('c, 'd) Nx.t) -> ('a, 'b) Nx.t list -> ('a, 'b) Nx.t list -> ('c, 'd) Nx.t * ('c, 'd) Nx.t

jvps f xs vs is like jvp for functions with multiple inputs. Each tangent in vs must have the same shape as the corresponding primal in xs.

Stopping gradients

val no_grad : (unit -> 'a) -> 'a

no_grad f evaluates f () without recording operations for automatic differentiation. All tensors produced inside f are treated as constants by enclosing gradient computations.
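
A sketch: a statistic computed under no_grad is a constant to the surrounding grad (Nx.mean and Nx.sub are assumed to broadcast):

  let x = Nx.create Nx.float32 [| 3 |] [| 1.; 2.; 3. |]
  let f x =
    (* the mean is computed outside AD and treated as a constant *)
    let m = Rune.no_grad (fun () -> Nx.mean x) in
    Nx.sum (Nx.mul (Nx.sub x m) (Nx.sub x m))
  let g = Rune.grad f x   (* gradient flows through x but not through m *)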

val detach : ('a, 'b) Nx.t -> ('a, 'b) Nx.t

detach x is a copy of x that is treated as a constant with respect to automatic differentiation.

See also no_grad.
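
A sketch: normalise by a detached scale so the gradient flows only through the numerator (Nx.abs, Nx.max, and Nx.div are assumed):

  let f x =
    (* the scale is a constant with respect to AD *)
    let scale = Rune.detach (Nx.max (Nx.abs x)) in
    Nx.sum (Nx.div x scale)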

Custom differentiation rules

Override automatic differentiation with user-supplied forward and backward (or tangent) rules. Useful for implicit differentiation, surrogate gradients, and other computations where the derivative is algorithmically distinct from the primal.

Under reverse-mode AD (grad, vjp), the custom backward rule is used instead of tracing through the forward function. Under forward-mode AD (jvp) or outside AD, the forward function is traced normally.

Higher-order derivatives. The backward function runs outside the inner handler's continuation, so its Nx operations are traced by enclosing AD handlers. This means nested differentiation, e.g. grad (fun x -> grad f_custom x) x where f_custom is built with custom_vjp, works correctly; see the sketch below.
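
A sketch of such nesting, using custom_vjp (documented below) with a hand-written backward rule (Nx.mul and Nx.scalar are assumed):

  let cube =
    Rune.custom_vjp
      ~fwd:(fun x -> (Nx.mul x (Nx.mul x x), x))  (* save x as the residual *)
      ~bwd:(fun x g ->
        (* d/dx x^3 = 3x^2; these ops are traced by any enclosing grad *)
        Nx.mul g (Nx.mul (Nx.scalar Nx.float32 3.) (Nx.mul x x)))

  let x = Nx.scalar Nx.float32 2.
  let d2 = Rune.grad (fun x -> Rune.grad cube x) x  (* d2/dx2 x^3 = 6x = 12 *)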

val custom_vjp : fwd:(('a, 'b) Nx.t -> ('c, 'd) Nx.t * 'res) -> bwd:('res -> ('c, 'd) Nx.t -> ('a, 'b) Nx.t) -> ('a, 'b) Nx.t -> ('c, 'd) Nx.t

custom_vjp ~fwd ~bwd x computes fwd x with a custom VJP rule.

fwd returns (y, residuals) where y is the output and residuals is auxiliary data saved for the backward pass (e.g. intermediate values needed by the backward rule). residuals is not differentiated.

bwd residuals g receives the output cotangent g and returns the input cotangent. It is only called under reverse-mode AD (grad, vjp); under forward-mode AD (jvp) or outside AD, fwd is traced normally instead.
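
A sketch of a straight-through surrogate gradient (Nx.sign is assumed): forward computes sign x, while backward passes the cotangent through unchanged.

  let sign_ste x =
    Rune.custom_vjp
      ~fwd:(fun x -> (Nx.sign x, ()))   (* no residuals needed *)
      ~bwd:(fun () g -> g)              (* surrogate: identity backward rule *)
      x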

val custom_vjps : fwd:(('a, 'b) Nx.t list -> ('c, 'd) Nx.t * 'res) -> bwd:('res -> ('c, 'd) Nx.t -> ('a, 'b) Nx.t list) -> ('a, 'b) Nx.t list -> ('c, 'd) Nx.t

custom_vjps ~fwd ~bwd xs is like custom_vjp for functions with multiple inputs. bwd must return a list of the same length as xs.

val custom_jvp : fwd:(('a, 'b) Nx.t -> ('c, 'd) Nx.t) -> jvp_rule:(('a, 'b) Nx.t -> ('a, 'b) Nx.t -> ('c, 'd) Nx.t * ('c, 'd) Nx.t) -> ('a, 'b) Nx.t -> ('c, 'd) Nx.t

custom_jvp ~fwd ~jvp_rule x computes fwd x with a custom JVP rule.

jvp_rule primal tangent receives the primal input and its tangent, and returns (y, dy) where y is the primal output and dy is its tangent. It is only called under forward-mode AD (jvp); under reverse-mode AD (grad, vjp) or outside AD, fwd is traced normally instead.
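
A sketch that gives the elementwise square an explicit tangent rule (Nx.mul and Nx.add are assumed):

  let square x =
    Rune.custom_jvp
      ~fwd:(fun x -> Nx.mul x x)
      ~jvp_rule:(fun primal tangent ->
        let y = Nx.mul primal primal in
        (* d(x^2) = 2 x dx, written without a scalar literal *)
        let dy = Nx.add (Nx.mul primal tangent) (Nx.mul tangent primal) in
        (y, dy))
      x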

val custom_jvps : fwd:(('a, 'b) Nx.t list -> ('c, 'd) Nx.t) -> jvp_rule:(('a, 'b) Nx.t list -> ('a, 'b) Nx.t list -> ('c, 'd) Nx.t * ('c, 'd) Nx.t) -> ('a, 'b) Nx.t list -> ('c, 'd) Nx.t

custom_jvps ~fwd ~jvp_rule xs is like custom_jvp for functions with multiple inputs. jvp_rule primals tangents receives a list of primals and their tangents, and returns (y, dy).

Gradient checking

Compare autodiff gradients against finite-difference approximations. Useful for testing custom operations.

type method_ = [ `Central | `Forward | `Backward ]

The type for finite difference methods.

  • `Central: (f(x+h) - f(x-h)) / 2h. Most accurate; requires two evaluations per element.
  • `Forward: (f(x+h) - f(x)) / h.
  • `Backward: (f(x) - f(x-h)) / h.

val finite_diff : ?eps:float -> ?method_:method_ -> (('a, 'b) Nx.t -> ('c, 'd) Nx.t) -> ('a, 'b) Nx.t -> ('a, 'b) Nx.t

finite_diff f x is the gradient of scalar-valued f at x approximated by finite differences.

eps defaults to 1e-4. method_ defaults to `Central.
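
A sketch comparing the autodiff and finite-difference gradients:

  let f x = Nx.sum (Nx.mul x x)
  let x = Nx.create Nx.float32 [| 3 |] [| 1.; 2.; 3. |]
  let g_ad = Rune.grad f x
  let g_fd = Rune.finite_diff ~eps:1e-4 ~method_:`Central f x
  (* for `Central the two should agree to roughly O(eps^2) *)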

val finite_diff_jacobian : ?eps:float -> ?method_:method_ -> (('a, 'b) Nx.t -> ('c, 'd) Nx.t) -> ('a, 'b) Nx.t -> ('c, 'd) Nx.t

finite_diff_jacobian f x is the Jacobian of f at x approximated by finite differences.

eps defaults to 1e-4. method_ defaults to `Central.

type gradient_check_result = {
  max_abs_error : float;  (* largest absolute error across all elements *)
  max_rel_error : float;  (* largest relative error across all elements *)
  mean_abs_error : float;  (* mean absolute error *)
  mean_rel_error : float;  (* mean relative error *)
  failed_indices : (int array * float * float * float) list;
    (* (index, autodiff, finite_diff, abs_error) for each failed element *)
  passed : bool;  (* true iff no element exceeded the tolerances *)
  num_checked : int;  (* number of elements checked *)
  num_failed : int;  (* number of elements that exceeded tolerances *)
}

The type for gradient check results.

val check_gradient : ?eps:float -> ?rtol:float -> ?atol:float -> ?verbose:bool -> ?check_indices:int list option -> ?method_:[ `Central | `Forward | `Backward ] -> ((float, 'a) Nx.t -> ('b, 'c) Nx.t) -> (float, 'a) Nx.t -> [ `Pass of gradient_check_result | `Fail of gradient_check_result ]

check_gradient f x compares the autodiff gradient of f at x against a finite-difference approximation.

An element passes when abs_error <= atol or rel_error <= rtol.

  • eps defaults to 1e-4.
  • rtol defaults to 2e-3.
  • atol defaults to 2e-3.
  • verbose defaults to false. When true, prints per-element failures and a summary to standard output.
  • check_indices defaults to None (check all elements). When Some indices, only the listed flat indices are checked.
  • method_ defaults to `Central.

See also check_gradients.
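
A sketch (field access uses the gradient_check_result record above):

  let f x = Nx.sum (Nx.mul x x)
  let x = Nx.create Nx.float32 [| 3 |] [| 1.; 2.; 3. |]
  let () =
    match Rune.check_gradient ~verbose:true f x with
    | `Pass r -> Printf.printf "ok (%d elements)\n" r.Rune.num_checked
    | `Fail r -> Printf.printf "%d of %d failed\n" r.Rune.num_failed r.Rune.num_checked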

val check_gradients : ?eps:float -> ?rtol:float -> ?atol:float -> ?verbose:bool -> ?method_:[ `Central | `Forward | `Backward ] -> ((float, 'a) Nx.t list -> ('b, 'c) Nx.t) -> (float, 'a) Nx.t list -> [ `Pass of gradient_check_result list | `Fail of gradient_check_result list ]

check_gradients f xs is like check_gradient for functions with multiple inputs. Returns one gradient_check_result per input tensor.

Optional parameters have the same defaults as check_gradient.

Vectorising map

Map a computation over a batch dimension. vmap transforms a function that operates on single examples into one that operates on batches, without the user writing explicit batch loops.

type axis_spec =
  | Map of int  (* map over the axis at this index *)
  | NoMap  (* do not map; broadcast the input as-is *)

The type for per-input axis specifications.

type 'a in_axes_spec =
  | Single of axis_spec  (* apply to all inputs *)
  | Container of 'a  (* per-input specifications *)

The type for input axis specifications.

type 'a out_axes_spec =
  | OutSingle of int option  (* stack outputs along this axis (None to discard) *)
  | OutContainer of 'a  (* per-output specifications *)

The type for output axis specifications.

val vmap : ?in_axes:'a in_axes_spec -> ?out_axes:'b out_axes_spec -> ?axis_name:string -> ?axis_size:int -> (('c, 'd) Nx.t -> ('e, 'f) Nx.t) -> ('c, 'd) Nx.t -> ('e, 'f) Nx.t

vmap f x is a vectorised version of f applied to x.

  • in_axes defaults to Single (Map 0).
  • out_axes defaults to OutSingle (Some 0).
  • axis_name is an optional label for the mapped axis (used in error messages).
  • axis_size overrides the batch size inferred from the input shape. Required when all inputs use NoMap.

See also vmaps.
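
A sketch that batches a per-example reduction over axis 0, using the defaults above (Nx.create accepting a flat array is assumed):

  let f x = Nx.sum (Nx.mul x x)   (* per-example: vector -> scalar *)
  let batched = Rune.vmap f       (* batch of vectors -> batch of scalars *)
  let xs = Nx.create Nx.float32 [| 4; 3 |] (Array.init 12 float_of_int)
  let ys = batched xs             (* shape [| 4 |] *)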

val vmaps : ?in_axes:Rune__.Vmap.axis_spec list -> ?out_axes:'b Rune__.Vmap.out_axes_spec -> ?axis_name:string -> ?axis_size:int -> (('c, 'd) Nx.t list -> ('e, 'f) Nx.t) -> ('c, 'd) Nx.t list -> ('e, 'f) Nx.t

vmaps f xs is like vmap for functions with multiple inputs. Each element of in_axes corresponds to one input in xs.

in_axes defaults to Map 0 for every input.

Debugging

val debug : ('a -> 'b) -> 'a -> 'b

debug f x applies f to x under a tracing handler that prints every tensor operation, its inputs, and its outputs to standard output.