Library
Module
Module type
Parameter
Class
Class type
type tn = Arrayjit.Tnode.t
type asgns = Arrayjit.Assignments.t
type init_op = Arrayjit.Ops.init_op
type fetch_op = Arrayjit.Assignments.fetch_op
type projections = Arrayjit.Indexing.projections
type diff = {
grad : tn;
zero_grads : asgns;
(** Prepares for backpropagation. Always compile as: [Seq (zero_grads, backprop)]. *)
backprop : asgns;
(** Backpropagates for the tensor and its descendants, which typically means adding partial
    gradients to the gradient tensors of the subtensors, then of the sub-subtensors, etc. *)
}
type t = {
forward : asgns;
diff : diff Base.option;
id : Base.int;
(** Same as [value.id]. *)
value : tn;
shape : Shape.t;
(** The eventual shape of [t.value] and [t.diff.grad], incorporating the current state of
    shape inference. *)
children : subtensor Base.list;
}
(** Information needed for compositional code generation. *)
val sexp_of_t : t -> Sexplib0.Sexp.t
val sexp_of_subtensor : subtensor -> Sexplib0.Sexp.t
val comparator : (t, comparator_witness) Base.Comparator.t
val is_fwd_root : t -> Base.bool
val remove_fwd_root : t -> Base.unit
val is_bprop_root : t -> Base.bool
val remove_bprop_root : t -> Base.unit
val default_value_prec : Arrayjit.Ops.prec Base.ref
val default_grad_prec : Arrayjit.Ops.prec Base.ref
exception Session_error of Base.string * t Base.option
val raw_binop :
initialize_neutral:Base.bool ->
accum:Arrayjit.Ops.binop ->
t:t ->
lhs_is_grad:Base.bool ->
op:Arrayjit.Ops.binop ->
t1:t ->
rhs1_is_grad:Base.bool ->
rhs1_is_merge:Base.bool ->
t2:t ->
rhs2_is_grad:Base.bool ->
rhs2_is_merge:Base.bool ->
logic:Shape.compose_type ->
asgns
val raw_unop :
initialize_neutral:Base.bool ->
accum:Arrayjit.Ops.binop ->
t:t ->
lhs_is_grad:Base.bool ->
op:Arrayjit.Ops.unop ->
t1:t ->
rhs_is_grad:Base.bool ->
rhs_is_merge:Base.bool ->
logic:Shape.transpose_type ->
asgns
val is_prohibit_grad : grad_spec -> Base.bool
val op :
label:Base.string Base.list ->
?compose_op:Shape.compose_type ->
?transpose_op:Shape.transpose_type ->
?init_op:init_op ->
op_asn:(v:tn -> projections:projections Base.Lazy.t -> asgns) ->
grad_asn:(v:tn -> g:tn -> projections:projections Base.Lazy.t -> asgns) ->
?grad_spec:grad_spec ->
(debug_name:Base.string -> id:Base.int -> Shape.t) ->
t Base.list ->
t
val unop :
label:Base.string Base.list ->
?transpose_op:Shape.transpose_type ->
op_asn:(v:tn -> t1:t -> projections:projections Base.Lazy.t -> asgns) ->
grad_asn:
(v:tn -> g:tn -> t1:t -> projections:projections Base.Lazy.t -> asgns) ->
?grad_spec:grad_spec ->
t ->
t
val term :
label:Base.string Base.list ->
grad_spec:grad_spec ->
?batch_dims:Base.int Base.list ->
?input_dims:Base.int Base.list ->
?output_dims:Base.int Base.list ->
?batch_axes:(Base.string * Base.int) Base.list ->
?input_axes:(Base.string * Base.int) Base.list ->
?output_axes:(Base.string * Base.int) Base.list ->
?deduced:Shape.deduce_within_shape ->
?init_op:init_op ->
?fetch_op:(v:tn -> fetch_op) ->
Base.unit ->
t
(** A terminal: a constant, a parameter, an input of the model. The semantics of shape
    specification is the same as in [Shape.make], and by default the shape will be inferred. *)
val number :
?label:Base.string Base.list ->
?axis_label:Base.string ->
?grad_spec:grad_spec ->
Base.float ->
t
(** A number: a tensor with a single axis of one dimension, initialized to the given value.
    [grad_spec] is by default [Prohibit_grad]. *)
val ndarray :
?label:Base.string Base.list ->
?grad_spec:grad_spec ->
?batch_dims:Base.int Base.list ->
?input_dims:Base.int Base.list ->
?output_dims:Base.int Base.list ->
?batch_axes:(Base.string * Base.int) Base.list ->
?input_axes:(Base.string * Base.int) Base.list ->
?output_axes:(Base.string * Base.int) Base.list ->
?strict:Base.bool ->
Base.float Base.array ->
t
(** A tensor with an explicit shape, initialized to the given values. Omitted shape rows
    default to no axes. [grad_spec] is by default [Prohibit_grad]. If [strict] is [true] (the
    default), the given values must fill the tensor's [value] node precisely; otherwise, the
    values will be looped over to populate the [value] node. *)
val param :
?more_label:Base.string Base.list ->
?input_dims:Base.int Base.list ->
?output_dims:Base.int Base.list ->
?input_axes:(Base.string * Base.int) Base.list ->
?output_axes:(Base.string * Base.int) Base.list ->
?deduced:Shape.deduce_within_shape ->
?strict:Base.bool ->
?values:Base.float Base.array ->
Base.string ->
t
val non_and_embedded_nodes :
t ->
(t, comparator_witness) Base.Set.t * (t, comparator_witness) Base.Set.t
(** A forward root is a tensor that is not (currently) used to compute another tensor.
    [consume_forward_code t] ensures [t] is a forward root, removes it from forward roots, and
    checks that there are no other forward roots for tensors with children. *)
(** A backprop root is a tensor with a gradient that is not (currently) receiving gradients
    from another tensor, i.e. it is not currently used to compute a tensor with a gradient.
    [consume_backprop_code t] ensures [t] is a backprop root, removes it from backprop roots,
    and checks that there are no other backprop roots for tensors with children. *)
val header : t -> Base.string
(** Converts the ID, label and dimensions of a node to a string. *)
val log_debug_info : from_log_level:Base.int -> t -> Base.unit
(** Logs debug information about the tensor on the default ppx_minidebug runtime. *)
type array_print_style = [
| `Default
(** The inner rectangles comprise both an input and an output axis, if available. Similarly,
    the outer rectangle comprises a second-from-end input axis and a second-from-end output
    axis, if available. At least one batch axis is output, when available. The axes that
    couldn't be output are printed at position/dimension [0]. *)
| `N5_layout of Base.string
(** The string should provide exclusively non-negative integer pseudo-labels. The numbers
    [0]-[4] represent the priorities of the axes to be printed out, where the priorities
    correspond to, from highest: horizontal, vertical direction of the inner rectangle,
    horizontal, vertical direction of the outer rectangle, repetition (see also
    [Node.pp_print]). The numbers [n >= 5] stand for the actual positions [n - 5] within the
    corresponding axes. *)
| `Label_layout of (Base.string * Base.int) Base.list
(** The association from axis labels to integers. The negative numbers [-5] to [-1] represent
    the priorities of the axes to be printed out, where the priorities correspond to, from
    highest: horizontal, vertical direction of the inner rectangle, horizontal, vertical
    direction of the outer rectangle, repetition (as above). The numbers [n >= 0] stand for
    the actual positions within the corresponding axes. Unspecified axes are printed at
    position [0]. *)
| `Inline
(** The tensors are printed linearly, in a bracketed manner, optionally prefixed with the
    labels specification. Note that the syntax causes ambiguity for 1-dimensional input axes
    (underscores are used for axes without explicit labels); when there is a 1-dimensional
    input axis, we output the labels specification even if there are no axis labels as a way
    to display the number of axes. The axis nesting is right-to-left (rightmost is innermost).
    The input axes are innermost and the batch axes outermost. The input axes use [,] as a
    separator and [()] as axis delimiters, but the delimiter for the outermost (i.e. leftmost)
    axis is omitted. The output axes use [;] as a separator and [[]] as axis delimiters
    (obligatory). The batch axes use [;] as a separator and [[||]] as axis delimiters
    (obligatory). *)
]
(** We print out up to 5 axes when printing a tensor, as a grid (outer rectangle) of (inner)
    rectangles, possibly repeated (screens). *)
val to_printbox :
?single_node:Base.bool ->
?entries_per_axis:Base.int ->
?with_id:Base.bool ->
?with_shape:Base.bool ->
?with_value:Base.bool ->
with_grad:Base.bool ->
depth:Base.int ->
t ->
PrintBox.t
val print :
with_grad:Base.bool ->
with_code:Base.bool ->
?force:Base.bool ->
?with_low_level:Base.bool ->
array_print_style ->
t ->
Base.unit
val print_forward_roots :
with_grad:Base.bool ->
with_code:Base.bool ->
array_print_style ->
Base.unit
val print_tree :
?entries_per_axis:Base.int ->
?with_backend_info:Base.bool ->
?with_id:Base.bool ->
?with_shape:Base.bool ->
?with_value:Base.bool ->
with_grad:Base.bool ->
depth:Base.int ->
t ->
Base.unit
val debug_name : t -> Base.string
val value_1d_points :
?from_axis:Base.int ->
xdim:Base.int ->
t ->
Base.float Base.array
val value_2d_points :
?from_axis:Base.int ->
xdim:Base.int ->
ydim:Base.int ->
t ->
(Base.float * Base.float) Base.array
val grad_1d_points :
?from_axis:Base.int ->
xdim:Base.int ->
t ->
Base.float Base.array
val grad_2d_points :
?from_axis:Base.int ->
xdim:Base.int ->
ydim:Base.int ->
t ->
(Base.float * Base.float) Base.array
val set_value : t -> Base.int Base.array -> Base.float -> Base.unit
val get_value : t -> Base.int Base.array -> Base.float
val set_grad : t -> Base.int Base.array -> Base.float -> Base.unit
val get_grad : t -> Base.int Base.array -> Base.float
val set_values : t -> Base.float Base.array -> Base.unit
val get_values : t -> Base.float Base.array