package caisar

Module Nir.Node

Node descriptions

A node is composed of

  • a unique id of type int
  • a node description of type descr
  • the shape of its result, of type Shape.t
  • the element type of its result, of type ty

descr describes the supported operations. When an operation has the same name as an ONNX operator, it follows the semantics defined by the ONNX IR v8 and ONNX Opset v13 standards, documented at https://onnx.ai/onnx/operators/index.html.

A node only references its inputs: each node is assumed to return a single value.

type ty =
  | BFloat16
  | Float16
  | Float
  | UInt8
  | Int8
  | Int32
  | Int64
type descr =
  | Constant of {
      data : Gentensor.t;
    }
      (* A constant tensor, used to store non-varying parameters during inference. *)
  | Add of {
      input1 : t;
      input2 : t;
    }
  | Sub of {
      input1 : t;
      input2 : t;
    }
  | Mul of {
      input1 : t;
      input2 : t;
    }
  | Div of {
      input1 : t;
      input2 : t;
    }
  | Sum of {
      input : t;
    }
  | Matmul of {
      input1 : t;
      input2 : t;
    }
  | QLinearMatMul of {
      inputA : t;
      inputA_scale : t;
      inputA_zero_point : t;
      inputB : t;
      inputB_scale : t;
      inputB_zero_point : t;
      y_scale : t;
      y_zero_point : t;
    }
  | Gemm of {
      inputA : t;
      inputB : t;
      inputC : t Base.option;
      alpha : Base.float;
      beta : Base.float;
      transA : Base.int;
      transB : Base.int;
    }
  | QGemm of {
      inputA : t;
      inputA_scale : t;
      inputA_zero_point : t;
      inputB : t;
      inputB_scale : t;
      inputB_zero_point : t;
      inputC : t Base.option;
      y_scale : t Base.option;
      y_zero_point : t Base.option;
      alpha : Base.float;
      transA : Base.int;
      transB : Base.int;
    }
      (* Not an ONNX operator of the default domain. Documentation at:
         • https://xadupre.github.io/draft/inference/operators/onnx_commicrosoft_QGemm.html
         • https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftqgemm *)
  | LogSoftmax
  | Sigmoid of {
      input : t;
    }
  | ReLu of {
      input : t;
    }
  | Softmax of {
      input : t;
      axis : Base.int;
    }
  | Transpose of {
      input : t;
        (* Called "data" in the ONNX documentation: https://onnx.ai/onnx/operators/onnx__Transpose.html *)
      perm : Base.int Base.list;
    }
  | Squeeze of {
      data : t;
      axes : t Base.option;
    }
  | MaxPool
  | Conv
  | Reshape of {
      input : t;
      shape : t;
    }
  | Flatten of {
      input : t;
      axis : Base.int;
    }
  | Identity of {
      input : t;
    }
  | Input of {
      shape : Shape.t;
    }
  | RW_Linearized_ReLu
  | Concat of {
      inputs : t Base.list;
      axis : Base.int;
    }
  | Gather of {
      input : t;
      indices : t;
      axis : Base.int;
    }
  | ReduceSum of {
      input : t;
      axes : t Base.option;
      keepdims : Base.int;
      noop_with_empty_axes : Base.int;
    }
  | GatherND of {
      data : t;
      indices : t;
      batch_dims : Base.int;
    }
  | RandomNormal of {
      dtype : Base.int;
      mean : Base.float;
      scale : Base.float;
      seed : Base.float;
      shape : Base.int Base.array;
    }
  | Abs of {
      input : t;
    }
  | Log of {
      input : t;
    }
  | Exp of {
      input : t;
    }
  | Sign of {
      input : t;
    }
  | ArgMax of {
      input : t;
      axis : Base.int;
      keepdims : Base.bool;
    }
  | Pow of {
      input1 : t;
      input2 : t;
    }
  | QuantizeLinear of {
      x : t;
      y_scale : t;
      y_zero_point : t Base.option;
      axis : Base.int;
    }
  | DequantizeLinear of {
      x : t;
      x_scale : t;
      x_zero_point : t Base.option;
      axis : Base.int;
    }
and t = private {
  id : Base.int;
  descr : descr;
  shape : Shape.t;
    (* Describes the shape of the result of the node computation. *)
  ty : ty;
}
val equal : t -> t -> Base.bool
include Base.Hashtbl.Key.S with type t := t
val compare : t -> t -> int
val sexp_of_t : t -> Sexplib0.Sexp.t
val hash : t -> int

Two values of type t that compare equal must have equal hashes for the hashtable to behave properly.
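To make the hashing contract concrete, here is a sketch using the standard library's Hashtbl.Make functor on a hypothetical, simplified node record (node, Node_key and Node_table are illustrative names, not part of Nir.Node). Because equal, compare and hash all look only at the id, records that compare equal necessarily share a hash, so the table treats them as a single key.

```ocaml
(* Hypothetical, simplified stand-in for Nir.Node.t: only an id and a label. *)
type node = { id : int; label : string }

(* Equality, comparison and hashing all depend only on [id], so two values
   that compare equal are guaranteed to have equal hashes. *)
module Node_key = struct
  type t = node

  let equal a b = Int.equal a.id b.id
  let compare a b = Int.compare a.id b.id
  let hash n = Hashtbl.hash n.id
end

module Node_table = Hashtbl.Make (Node_key)

(* Two records with the same id but different labels land on the same key. *)
let bindings =
  let tbl = Node_table.create 16 in
  Node_table.replace tbl { id = 1; label = "input" } "first";
  Node_table.replace tbl { id = 1; label = "input (copy)" } "second";
  Node_table.length tbl
(* bindings = 1 *)
```

If hash also mixed in the label while equal ignored it, the two records above could fall into different buckets and the table would silently hold two bindings for what equal considers the same key.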

include Base.Comparator.S with type t := t
type comparator_witness
val create : descr -> t

create descr returns a new node (a value of type t) with a unique id and its shape computed according to the ONNX semantics.

val reducesum_int : ?encode:Base.bool -> t -> t

reducesum_int n sums all elements of the node n.

val reducesum_with_ignored_indices : t -> Base.int Base.list -> t

reducesum_with_ignored_indices n idx sums all elements of the node n, except elements of indices in idx.
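The two reductions can be illustrated on plain data. The sketch below models a tensor as a flat, row-major float array and reads idx as positions in that flat array; both are assumptions, and the functions compute values directly rather than building nodes as the library does.

```ocaml
(* Assumption: a tensor is modelled as a flat, row-major float array, and the
   indices in [idx] are positions in that flat array. *)

(* Sum of all elements: the value [reducesum_int n] is documented to compute. *)
let reducesum (t : float array) : float = Array.fold_left ( +. ) 0. t

(* Sum of all elements except those whose flat index appears in [idx]. *)
let reducesum_with_ignored_indices (t : float array) (idx : int list) : float =
  let total = ref 0. in
  Array.iteri (fun i x -> if not (List.mem i idx) then total := !total +. x) t;
  !total

let t = [| 1.; 2.; 3.; 4. |]
let all = reducesum t (* 10. *)
let partial = reducesum_with_ignored_indices t [ 0; 3 ] (* 2. +. 3. = 5. *)
```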

val gather_int : ?encode:Base.bool -> t -> Base.int -> t

gather_int n i gathers the i-th element of the node n.

val gather_ints : ?encode:Base.bool -> t -> Base.int Base.list -> t

gather_ints n l gathers the i-th element of the node n for each element i of l.
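Under the same kind of simplification (a tensor as a flat, row-major float array, and "the i-th element" read as the element at flat index i, both assumptions), the two gather functions reduce to plain indexing:

```ocaml
(* Assumption: a tensor is modelled as a flat, row-major float array, and
   "the i-th element" is the element at flat index i. *)

(* Counterpart of [gather_int n i]. *)
let gather_int (t : float array) (i : int) : float = t.(i)

(* Counterpart of [gather_ints n l]: one element per index of [l], in the
   order of [l]; indices may repeat. *)
let gather_ints (t : float array) (l : int list) : float array =
  Array.of_list (List.map (gather_int t) l)

let picked = gather_ints [| 10.; 20.; 30.; 40. |] [ 3; 0; 3 ]
(* picked = [| 40.; 10.; 40. |] *)
```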

val map : (t -> t) -> t -> t

map f n replaces each direct input i of n with f i.

val map_rec : (t -> t) -> t -> t

map_rec f n replaces, top-down, every node i reachable from n with f i.
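The difference between map and map_rec can be sketched on a hypothetical miniature node type with a handful of operators. This version assumes that map_rec applies f to n itself before descending into the inputs of the result, and it does not preserve sharing between sub-nodes; it only illustrates the traversal order.

```ocaml
(* Hypothetical miniature of a node graph; not the real Nir.Node.t. *)
type node =
  | Input of string
  | Const of float
  | Add of node * node
  | Relu of node

(* [map f n]: apply [f] to the direct inputs of [n] only. *)
let map f = function
  | (Input _ | Const _) as leaf -> leaf
  | Add (a, b) -> Add (f a, f b)
  | Relu a -> Relu (f a)

(* [map_rec f n]: top-down rewrite. [f] is applied to a node first, then the
   traversal continues into the inputs of the rewritten node. Shared
   sub-nodes are not memoized in this sketch. *)
let rec map_rec f n = map (map_rec f) (f n)

(* Example rewrite: replace every input by the constant 1. *)
let rename = function Input _ -> Const 1. | n -> n
let g = Relu (Add (Input "x", Const 2.))

let shallow = map rename g
(* shallow = g: the only direct input of g is the Add node, left unchanged *)
let deep = map_rec rename g
(* deep = Relu (Add (Const 1., Const 2.)) *)
```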

val replace_input : (Base.unit -> t) -> t -> t

replace_input f n replaces the input node in n with f ().

val preds : t -> t Base.list

preds n returns the direct predecessors of n.

val iter_rec : (t -> Base.unit) -> t -> Base.unit

iter_rec f n applies f to n and to all of its (transitive) predecessors, respecting topological order.
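A sketch of preds and iter_rec on a hypothetical record-based graph: iter_rec here is a depth-first post-order traversal that visits each reachable node once, which yields a topological order (every node after its predecessors). The record type and the Hashtbl keyed by id are assumptions made for the illustration, not the library's implementation.

```ocaml
(* Hypothetical miniature node: an id, a label and the list of direct inputs. *)
type node = { id : int; label : string; inputs : node list }

(* Counterpart of [preds]: the direct predecessors of a node. *)
let preds n = n.inputs

(* Counterpart of [iter_rec f n]: visit every node reachable from [n]
   (including [n]) exactly once, each node after all of its predecessors.
   Depth-first post-order; shared nodes are visited once thanks to [seen]. *)
let iter_rec f n =
  let seen = Hashtbl.create 16 in
  let rec visit n =
    if not (Hashtbl.mem seen n.id) then begin
      Hashtbl.add seen n.id ();
      List.iter visit (preds n);
      f n
    end
  in
  visit n

let order =
  let x = { id = 0; label = "x"; inputs = [] } in
  let w = { id = 1; label = "w"; inputs = [] } in
  let mm = { id = 2; label = "matmul"; inputs = [ x; w ] } in
  let out = { id = 3; label = "add"; inputs = [ mm; x ] } in
  let acc = ref [] in
  iter_rec (fun n -> acc := n.label :: !acc) out;
  List.rev !acc
(* order = ["x"; "w"; "matmul"; "add"]: x is shared but visited once *)
```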

val compute_shape : t -> Shape.t
val (+) : t -> t -> t
val ( * ) : t -> t -> t
val (-) : t -> t -> t
val (^^) : t -> t -> t
val (+.) : t -> Base.float -> t
val ( *. ) : t -> Base.float -> t
val (^.) : t -> Base.float -> t
val mul_float : t -> Base.float -> t
val div_float : ?encode:Base.bool -> t -> Base.float -> t
val concat_0 : t Base.list -> t
val reshape : Shape.t -> t -> t
val sign : t -> t
val transpose_op : ?perm:Base.int Base.list -> t -> t

transpose_op ~perm n transposes the shape of node n according to perm. For instance, if perm = [0; 2; 1] and n has shape [1; 2; 3], the output has shape [1; 3; 2]. perm is assumed to have as many elements as n has dimensions.
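The shape rule from the example can be checked on plain int lists. transpose_shape is a hypothetical helper that computes only the output shape; reversing the dimensions when perm is omitted mirrors the ONNX Transpose default and is an assumption about transpose_op.

```ocaml
(* Output shape of a transposition: dimension i of the result is dimension
   perm.(i) of the input. Without [perm], dimensions are reversed, which is
   the ONNX Transpose default (assumed here for transpose_op as well). *)
let transpose_shape ?perm (shape : int list) : int list =
  let rank = List.length shape in
  let perm =
    match perm with
    | Some p -> p
    | None -> List.init rank (fun i -> rank - 1 - i)
  in
  if List.length perm <> rank then
    invalid_arg "transpose_shape: perm must have one entry per dimension";
  List.map (fun axis -> List.nth shape axis) perm

let example = transpose_shape ~perm:[ 0; 2; 1 ] [ 1; 2; 3 ] (* [1; 3; 2] *)
let reversed = transpose_shape [ 1; 2; 3 ] (* [3; 2; 1] *)
```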

val ( ** ) : t -> t -> t

n1 ** n2 is the Matmul of n1 and n2.

val add_one_dimension : t -> t

add_one_dimension n is a copy of n with one extra dimension of size 1 added at the front of its shape. For example, a tensor [[a,b],[c,d]] of shape (2,2) becomes the tensor [[[a,b],[c,d]]] of shape (1,2,2).

val add_one_dimension_back : t -> t

add_one_dimension_back n is a copy of n with one extra dimension of size 1 added at the back of its shape. For example, a tensor [[a,b],[c,d]] of shape (2,2) becomes the tensor [[[a],[b]],[[c],[d]]] of shape (2,2,1).
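Both helpers can be pictured at the level of shapes and, for a 2-D tensor, of nested arrays. The functions below are hypothetical illustrations that reproduce the two examples above; they are not the library's node-building code.

```ocaml
(* Shape-level view: a new dimension of size 1 at the front or at the back. *)
let add_one_dimension_shape (shape : int list) : int list = 1 :: shape
let add_one_dimension_back_shape (shape : int list) : int list = shape @ [ 1 ]

(* Data-level view for a 2-D tensor stored as nested arrays. *)
let add_front (m : 'a array array) : 'a array array array = [| m |]

let add_back (m : 'a array array) : 'a array array array =
  Array.map (Array.map (fun x -> [| x |])) m

let m = [| [| 1; 2 |]; [| 3; 4 |] |]
let front = add_front m (* [| [| [| 1; 2 |]; [| 3; 4 |] |] |], shape (1,2,2) *)
let back = add_back m (* [| [| [| 1 |]; [| 2 |] |]; [| [| 3 |]; [| 4 |] |] |], shape (2,2,1) *)
```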

val sum_list : ?shp:Shape.t -> t Base.list -> t

sum_list ~shp ns is a node corresponding to the sum of the nodes in ns. If ns is empty, it returns a tensor of shape shp filled with 0s. By default, shp is the shape of a single float.
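A sketch of the sum_list semantics on plain data, assuming a tensor is a shape paired with a flat float array and approximating Shape.t by an int list. The default scalar shape is modelled as the rank-0 shape [], which is an assumption about what "a single float" means in the library.

```ocaml
(* Assumption: a tensor is a shape plus its elements in a flat float array. *)
type tensor = { shape : int list; data : float array }

let zeros shape = { shape; data = Array.make (List.fold_left ( * ) 1 shape) 0. }

(* Element-wise sum of same-shaped tensors. An empty list yields zeros of
   shape [shp]; the default scalar shape is modelled here as []. *)
let sum_list ?(shp = []) (ns : tensor list) : tensor =
  match ns with
  | [] -> zeros shp
  | first :: rest ->
    List.fold_left
      (fun acc t ->
        if t.shape <> acc.shape then invalid_arg "sum_list: shape mismatch";
        { acc with data = Array.map2 ( +. ) acc.data t.data })
      first rest

let s =
  sum_list
    [ { shape = [ 2 ]; data = [| 1.; 2. |] }; { shape = [ 2 ]; data = [| 3.; 4. |] } ]
(* s.data = [| 4.; 6. |] *)
let empty = sum_list ~shp:[ 2; 2 ] [] (* four zeros *)
```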

val partial_dot_product : ?shp:Shape.t -> t Base.array -> t Base.array -> Base.int -> Base.int -> t

partial_dot_product ~shp arr1 arr2 first last is a node corresponding to

  arr1.(first) * arr2.(first) + arr1.(first + 1) * arr2.(first + 1) + ... + arr1.(last - 1) * arr2.(last - 1)

that is, the sum of the pairwise products over the indices first to last - 1. It is assumed that arr1 and arr2 contain tensors with the same shape. Edge cases include:

  • if last > length arr1 or last > length arr2, it fails
  • if first >= last, it returns a tensor whose values are all initialized to 0. The shape of this tensor is determined in the following order:
      1. if length arr1 <> 0, use the shape of arr1.(0)
      2. if length arr2 <> 0, use the shape of arr2.(0)
      3. if shp <> None, use shp
      4. otherwise, fail
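The formula and the edge cases can be sketched on plain data. This version assumes every tensor is a flat float array of the same length, treats * and + as element-wise (as Mul and Add nodes would), and simplifies shp to an element count; it follows the documented order for choosing the size of the zero result.

```ocaml
(* Assumptions: each tensor is a flat float array of the same length, the
   products and the sum are element-wise, and [shp] is an element count. *)
let partial_dot_product ?shp (arr1 : float array array) (arr2 : float array array)
    (first : int) (last : int) : float array =
  if last > Array.length arr1 || last > Array.length arr2 then
    invalid_arg "partial_dot_product: last is out of bounds";
  if first >= last then begin
    (* Empty range: zeros, sized from arr1.(0), then arr2.(0), then shp. *)
    let size =
      if Array.length arr1 <> 0 then Array.length arr1.(0)
      else if Array.length arr2 <> 0 then Array.length arr2.(0)
      else
        match shp with
        | Some n -> n
        | None -> invalid_arg "partial_dot_product: cannot determine the shape"
    in
    Array.make size 0.
  end
  else begin
    (* Accumulate arr1.(k) * arr2.(k) for k = first .. last - 1. *)
    let acc = Array.make (Array.length arr1.(first)) 0. in
    for k = first to last - 1 do
      Array.iteri (fun j x -> acc.(j) <- acc.(j) +. (x *. arr2.(k).(j))) arr1.(k)
    done;
    acc
  end

let arr1 = [| [| 1.; 2. |]; [| 3.; 4. |]; [| 5.; 6. |] |]
let arr2 = [| [| 1.; 1. |]; [| 2.; 2. |]; [| 10.; 10. |] |]
let p01 = partial_dot_product arr1 arr2 0 2 (* [| 7.; 10. |] *)
let p12 = partial_dot_product arr1 arr2 1 3 (* [| 56.; 68. |] *)
let p22 = partial_dot_product arr1 arr2 2 2 (* [| 0.; 0. |] *)
```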

transpose perm id is the position of the component at position id after a tensor transposition by perm. For instance, if id = [|10; 20; 30|] and perm = [2; 0; 1], then transpose perm id equals [|30; 10; 20|]. The empty permutation [] is interpreted as [r-1; r-2; ...; 1; 0], where r is the rank (following <https://onnx.ai/onnx/operators/onnx__Transpose.html#transpose-23>); otherwise perm is assumed to be a permutation of [0; 1; ...; r-1].

untranspose is the inverse of transpose, so that untranspose perm @@ transpose perm id equals id.
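A minimal sketch of both functions on index arrays, reproducing the example above. The signatures (perm as an int list, id as an int array) are inferred from that example and are assumptions, since the actual declarations are not shown here.

```ocaml
(* The empty permutation stands for [r-1; ...; 1; 0] (ONNX Transpose default). *)
let effective_perm perm rank =
  match perm with
  | [] -> List.init rank (fun i -> rank - 1 - i)
  | p -> p

(* Component i of the result is component perm.(i) of [id]. *)
let transpose (perm : int list) (id : int array) : int array =
  let perm = Array.of_list (effective_perm perm (Array.length id)) in
  Array.map (fun axis -> id.(axis)) perm

(* Inverse: put id.(i) back at position perm.(i). *)
let untranspose (perm : int list) (id : int array) : int array =
  let perm = Array.of_list (effective_perm perm (Array.length id)) in
  let res = Array.make (Array.length id) 0 in
  Array.iteri (fun i axis -> res.(axis) <- id.(i)) perm;
  res

let moved = transpose [ 2; 0; 1 ] [| 10; 20; 30 |] (* [| 30; 10; 20 |] *)
let back = untranspose [ 2; 0; 1 ] moved (* [| 10; 20; 30 |] *)
let reversed = transpose [] [| 10; 20; 30 |] (* [| 30; 20; 10 |] *)
```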

flatten shp axis id computes the position of the component at position id after flattening the shape shp at axis. Following the definition of Flatten <https://onnx.ai/onnx/operators/onnx__Flatten.html#flatten-23>, axis can be negative, in which case it counts from the end. For instance, if the shape is 10 * 10 * 10 * 10 * 10 (rank 5), axis = 3 and id = [|1; 2; 3; 4; 5|], then the result is [|123; 45|] (computed as [|3 + 10 * (2 + 10 * 1); 5 + 10 * 4|]).

unflatten is the inverse of flatten, so that unflatten shp axis @@ flatten shp axis id equals id (for valid inputs id).
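The index arithmetic can be sketched as follows, with the shape given as an int array (an assumption; the library uses its own shape type). The first axis coordinates are combined into a row-major outer index and the remaining ones into an inner index; unflatten splits them back with division and modulo.

```ocaml
(* Normalise a possibly negative axis against the rank, as ONNX Flatten does. *)
let norm_axis rank axis = if axis < 0 then axis + rank else axis

(* Position of component [id] after flattening [shp] into 2-D at [axis]. *)
let flatten (shp : int array) (axis : int) (id : int array) : int array =
  let axis = norm_axis (Array.length shp) axis in
  (* Row-major linear index of coordinates lo .. hi - 1. *)
  let linear lo hi =
    let acc = ref 0 in
    for k = lo to hi - 1 do acc := (!acc * shp.(k)) + id.(k) done;
    !acc
  in
  [| linear 0 axis; linear axis (Array.length shp) |]

(* Inverse of [flatten]: split each index back into per-dimension coordinates. *)
let unflatten (shp : int array) (axis : int) (flat : int array) : int array =
  let rank = Array.length shp in
  let axis = norm_axis rank axis in
  let id = Array.make rank 0 in
  let split lo hi v =
    let v = ref v in
    for k = hi - 1 downto lo do
      id.(k) <- !v mod shp.(k);
      v := !v / shp.(k)
    done
  in
  split 0 axis flat.(0);
  split axis rank flat.(1);
  id

let shp = [| 10; 10; 10; 10; 10 |]
let flat = flatten shp 3 [| 1; 2; 3; 4; 5 |] (* [| 123; 45 |] *)
let restored = unflatten shp 3 flat (* [| 1; 2; 3; 4; 5 |] *)
```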

val encode_qgemm : descr -> descr

Encodes the QGemm operator in terms of DequantizeLinear, QuantizeLinear and Gemm operators.
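Numerically, this rewrite amounts to dequantizing A and B, running a float Gemm, then quantizing the result. The sketch below assumes per-tensor scales and zero points, no C input, no transposition and a uint8 output; dequantize and quantize_u8 follow the ONNX DequantizeLinear and QuantizeLinear formulas (round half to even, saturation to [0, 255]). It only shows the arithmetic, not how encode_qgemm rewrites a descr, and all names are hypothetical.

```ocaml
(* ONNX DequantizeLinear, per-tensor: (x - zero_point) * scale. *)
let dequantize ~scale ~zero_point (x : int array array) : float array array =
  Array.map (Array.map (fun v -> float_of_int (v - zero_point) *. scale)) x

(* Float.round rounds halfway cases away from zero; step back toward zero
   when that lands on an odd value, giving round half to even. *)
let round_half_even (x : float) : float =
  let r = Float.round x in
  if Float.abs (x -. Float.trunc x) = 0.5 && Float.rem r 2. <> 0. then
    r -. Float.copy_sign 1. x
  else r

(* ONNX QuantizeLinear to uint8: saturate(round(x / scale) + zero_point). *)
let quantize_u8 ~scale ~zero_point (x : float array array) : int array array =
  Array.map
    (Array.map (fun v ->
         let q = int_of_float (round_half_even (v /. scale)) + zero_point in
         max 0 (min 255 q)))
    x

(* Float Gemm without C and without transposition: alpha * A * B. *)
let gemm ~alpha (a : float array array) (b : float array array) : float array array =
  let n = Array.length b and m = Array.length b.(0) in
  Array.map
    (fun row ->
      Array.init m (fun j ->
          let s = ref 0. in
          for k = 0 to n - 1 do s := !s +. (row.(k) *. b.(k).(j)) done;
          alpha *. !s))
    a

(* QGemm expressed as DequantizeLinear -> Gemm -> QuantizeLinear. *)
let qgemm ~a ~a_scale ~a_zp ~b ~b_scale ~b_zp ~y_scale ~y_zp ~alpha =
  let fa = dequantize ~scale:a_scale ~zero_point:a_zp a in
  let fb = dequantize ~scale:b_scale ~zero_point:b_zp b in
  quantize_u8 ~scale:y_scale ~zero_point:y_zp (gemm ~alpha fa fb)

let y =
  qgemm
    ~a:[| [| 130; 132 |] |] ~a_scale:0.5 ~a_zp:128
    ~b:[| [| 10 |]; [| 12 |] |] ~b_scale:1. ~b_zp:10
    ~y_scale:0.25 ~y_zp:0 ~alpha:1.
(* A dequantizes to [[1; 2]], B to [[0]; [2]]; A * B = [[4]]; 4 / 0.25 = 16,
   so y = [| [| 16 |] |] *)
```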