Definition and properties of the syntax of labels specifications and einsum notation:
- If there is a comma ',' anywhere in the initial text, the multicharacter version of the syntax is used; otherwise, labels are single characters.
- Special characters include '>', '|', '-', ',', '=', ';'.
- If labels_spec contains neither "|" nor "->", each label is of the kind Output.
- If the spec does not contain "|", labels to the left of "->" are Input and labels to the right are Output.
- Labels to the left of "|" are Batch, and labels between "|" and "->" are Input.
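The kind rules above can be illustrated with a small sketch. It is hypothetical and is not OCANNL's parser: it handles only the single-character syntax, ignores whitespace and row variables, and the names find_sub and split_spec are made up for this example. The rules do not say what happens to labels after "|" when there is no "->"; the sketch assumes they are Output.

```ocaml
(* Hypothetical sketch, not OCANNL's parser: split a single-character
   labels spec into its batch, input and output parts. Row variables,
   whitespace and the multicharacter (comma) syntax are not handled. *)

(* Index of the first occurrence of [sub] in [s], if any. *)
let find_sub s sub =
  let n = String.length s and m = String.length sub in
  let rec go i =
    if i + m > n then None
    else if String.sub s i m = sub then Some i
    else go (i + 1)
  in
  go 0

(* Returns (batch, input, output). ASSUMPTION: text after "|" with no
   "->" is treated as output; the rules above do not cover that case. *)
let split_spec (spec : string) : string * string * string =
  let batch, rest =
    match String.index_opt spec '|' with
    | None -> ("", spec)
    | Some i -> (String.sub spec 0 i, String.sub spec (i + 1) (String.length spec - i - 1))
  in
  match find_sub rest "->" with
  | None -> (batch, "", rest)
  | Some j -> (batch, String.sub rest 0 j, String.sub rest (j + 2) (String.length rest - j - 2))
```

For example, split_spec "b|ij->k" yields ("b", "ij", "k"), while split_spec "abc" puts every label in the output part.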
The labels ".."ident".." and "..." (where ident does not contain any of the special characters) may appear at most once per axis kind. They enable (in-the-middle) broadcasting for that axis kind in the einsum-related shape inference (like the ellipsis "..." in numpy.einsum), and are translated to row variables. The ellipsis "..." is context dependent: in the batch row it is the same as "..batch..", in the input row the same as "..input..", and in the output row the same as "..output..". When the same row variable is used in multiple rows, the corresponding broadcasted axes are matched pointwise in the resulting operation.
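The context-dependent meaning of "..." can be sketched as a function from a row's kind and a label to the row variable it denotes. The kind type and the row_var_of_label function are hypothetical names for this illustration; OCANNL's own representation of row variables differs.

```ocaml
(* Hypothetical types and names; OCANNL's own representation differs. *)
type kind = Batch | Input | Output

(* Returns the row-variable name a label stands for, or None for an
   ordinary axis label. "..." resolves according to the row's kind. *)
let row_var_of_label (k : kind) (label : string) : string option =
  let n = String.length label in
  if label = "..." then
    Some (match k with Batch -> "batch" | Input -> "input" | Output -> "output")
  else if n > 4 && String.sub label 0 2 = ".." && String.sub label (n - 2) 2 = ".." then
    Some (String.sub label 2 (n - 4))
  else None
```

With this sketch, "..." in a batch row and "..batch.." both resolve to the row variable batch, so two rows using them would have their broadcasted axes matched pointwise.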
The label "_" is a placeholder: it is not output to the resulting map, but it aligns the axes of other labels.
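A minimal sketch of the placeholder behavior, assuming a label map is simply an association list from label to axis position (label_positions is a hypothetical helper, not an OCANNL function):

```ocaml
(* Hypothetical helper: map each label to its axis position. "_" still
   occupies a position (so it shifts the labels after it) but is not
   included in the result. *)
let label_positions (labels : string list) : (string * int) list =
  labels
  |> List.mapi (fun i l -> (l, i))
  |> List.filter (fun (l, _) -> l <> "_")
```

For the labels a, _, b this gives [("a", 0); ("b", 2)]: the placeholder is dropped, yet b stays at position 2.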
include Ppx_compare_lib.Equal.S with type t := t
val equal : t Base__Ppx_compare_lib.equal
val compare_deduce_within_shape :
deduce_within_shape ->
deduce_within_shape ->
Base.int
val sexp_of_deduce_within_shape : deduce_within_shape -> Sexplib0.Sexp.t
val deduce_within_shape_of_sexp : Sexplib0.Sexp.t -> deduce_within_shape
type compose_type =
| Pointwise_bin
NumPy-style broadcast matching batch, input and output axes, e.g. as in s1 + s2.
| Compose
Compose the outputs of the second shape with the inputs of the first shape, i.e. the shape of fun x -> s1(s2(x)), or s1 * s2 where * is the inner product (e.g. matrix multiply).
| Einsum of Base.string
The binary "einsum" syntax: RHS1;RHS2=>LHS, where RHS1, RHS2 and LHS are labels specifications. Since OCANNL's extended einsum notation supports both axis variables and row variables, it makes the other compose types redundant. The axis_labels use pseudo-labels local to the notation to line up the axes and row variables. The symmetric difference (disjunctive union) of RHS1's and RHS2's pseudo-labels should equal the LHS pseudo-labels.
Note: the "right-hand side" is on the left, i.e. the syntax is "rhs=>lhs" or "rhs1;rhs2=>lhs".
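The symmetric-difference property can be checked on a simple spec such as matrix multiplication, "ij;jk=>ik". This is a sketch under strong simplifying assumptions: single-character labels only, no "|", "->" or row variables, and the helper names (find_sub, parse_binary_einsum, symdiff_matches) are hypothetical. It only tests the stated property and does not perform OCANNL's shape inference.

```ocaml
module CS = Set.Make (Char)

(* Index of the first occurrence of [sub] in [s], if any. *)
let find_sub s sub =
  let n = String.length s and m = String.length sub in
  let rec go i =
    if i + m > n then None
    else if String.sub s i m = sub then Some i
    else go (i + 1)
  in
  go 0

(* Split "rhs1;rhs2=>lhs" into its three labels specifications. *)
let parse_binary_einsum (spec : string) : string * string * string =
  match find_sub spec "=>" with
  | None -> invalid_arg "parse_binary_einsum: missing =>"
  | Some i -> (
      let rhs = String.sub spec 0 i in
      let lhs = String.sub spec (i + 2) (String.length spec - i - 2) in
      match String.split_on_char ';' rhs with
      | [ rhs1; rhs2 ] -> (rhs1, rhs2, lhs)
      | _ -> invalid_arg "parse_binary_einsum: expected rhs1;rhs2")

(* Set of single-character labels in a spec fragment. *)
let labels_of (s : string) : CS.t = CS.of_seq (String.to_seq s)

(* Does the symmetric difference of the RHS labels equal the LHS labels? *)
let symdiff_matches (spec : string) : bool =
  let rhs1, rhs2, lhs = parse_binary_einsum spec in
  let s1 = labels_of rhs1 and s2 = labels_of rhs2 in
  let sym = CS.union (CS.diff s1 s2) (CS.diff s2 s1) in
  CS.equal sym (labels_of lhs)
```

For "ij;jk=>ik" the shared label j is contracted away, leaving {i, k}, which matches the LHS; "ij;jk=>ijk" fails the check.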
val sexp_of_compose_type : compose_type -> Sexplib0.Sexp.t
val compose_type_of_sexp : Sexplib0.Sexp.t -> compose_type
val equal_compose_type : compose_type -> compose_type -> Base.bool
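The Compose case can be illustrated on a simplified shape with concrete dimensions. The dims record and the compose function are hypothetical; in particular, this sketch simply requires equal batch dimensions and does not model broadcasting or inference.

```ocaml
(* Hypothetical simplified shape with concrete dimensions per axis kind. *)
type dims = { batch : int list; input : int list; output : int list }

(* Shape of fun x -> s1 (s2 x): the outputs of s2 feed the inputs of s1.
   Simplification: batch dims must be equal (no broadcasting here). *)
let compose (s1 : dims) (s2 : dims) : dims =
  if s1.input <> s2.output then invalid_arg "compose: s1 inputs must match s2 outputs";
  if s1.batch <> s2.batch then invalid_arg "compose: batch dims differ";
  { batch = s1.batch; input = s2.input; output = s1.output }
```

For instance, a shape with output 4 and input 3 composed with a shape with output 3 and input 2 gives output 4 and input 2, as for a 4x3 by 3x2 matrix product.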
type transpose_type =
| Transpose
Swaps inputs and outputs of a shape, preserves batch axes.
| Pointwise_un
Preserves the shape.
| Permute of Base.string
The unary "einsum" syntax: RHS1=>LHS.
| Batch_slice of Arrayjit.Indexing.static_symbol
Removes the leftmost batch axis.
val equal_transpose_type : transpose_type -> transpose_type -> Base.bool
val sexp_of_transpose_type : transpose_type -> Sexplib0.Sexp.t
val transpose_type_of_sexp : Sexplib0.Sexp.t -> transpose_type
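The effect of the Transpose, Pointwise_un and Batch_slice cases on concrete dimensions can be sketched as follows. The dims record and transpose_kind type are hypothetical simplifications: Permute is omitted, and Batch_slice here ignores the static_symbol argument that the real constructor carries.

```ocaml
(* Hypothetical simplified shape with concrete dimensions per axis kind. *)
type dims = { batch : int list; input : int list; output : int list }

(* A subset of the transpose_type cases; Permute is omitted. *)
type transpose_kind = Transpose | Pointwise_un | Batch_slice

let apply_transpose (k : transpose_kind) (s : dims) : dims =
  match k with
  | Transpose -> { s with input = s.output; output = s.input } (* swap, keep batch *)
  | Pointwise_un -> s (* shape preserved *)
  | Batch_slice -> (
      match s.batch with
      | [] -> invalid_arg "apply_transpose: no batch axis to remove"
      | _ :: rest -> { s with batch = rest }) (* drop leftmost batch axis *)
```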
val make :
?batch_dims:Base.int Base.list ->
?input_dims:Base.int Base.list ->
?output_dims:Base.int Base.list ->
?batch_axes:(Base.string * Base.int) Base.list ->
?input_axes:(Base.string * Base.int) Base.list ->
?output_axes:(Base.string * Base.int) Base.list ->
?deduced:deduce_within_shape ->
debug_name:Base.string ->
id:Base.int ->
Base.unit ->
t
Creates a shape. id should be the id of the associated tensor (if any). At most one of each pair (batch_dims, batch_axes), (input_dims, input_axes), (output_dims, output_axes) should be given: if neither is given, the corresponding row will be inferred. batch_axes etc. provide labels for the dimensions of the corresponding axes. Note that these are dimension labels, not axis labels: they need not be unique within a row, are inferred when provided, and must match whenever the axis sizes must match.
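The "at most one of dims / axes per row" rule can be sketched for a single row. The row type and the row_of_args function are hypothetical and only mirror the rule stated above; they are not part of the Shape API.

```ocaml
(* Hypothetical row description: either left for inference, or known
   dimensions, each with an optional dimension label. *)
type row = Inferred | Known of (string option * int) list

(* Mirrors the "at most one of dims / axes" rule for a single row. *)
let row_of_args ?dims ?axes () : row =
  match (dims, axes) with
  | Some _, Some _ -> invalid_arg "row_of_args: give dims or axes, not both"
  | None, None -> Inferred
  | Some ds, None -> Known (List.map (fun d -> (None, d)) ds)
  | None, Some xs -> Known (List.map (fun (label, d) -> (Some label, d)) xs)
```

Calling row_of_args () leaves the row to inference, while ~axes:[("b", 2)] gives a known row whose single dimension carries the label b.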
val to_string_hum :
?style:
[< `Axis_number_and_size
| `Axis_size
| `Only_labels ] ->
t ->
Base.string
type logic =
| Broadcast of compose_type * t * t
Matches the shapes for a binary operation.
For Broadcast (Einsum spec, s1, s2) with spec of the form "ls1;ls2=>ls3", the labels of s1 and s2 must match according to the ls1, ls2 lineup, and the resulting shape inherits the labels according to the ls3 lineup.
| Transpose of transpose_type * t
Permutes the axes of a shape. One case of Transpose is to swap the inputs with the outputs of s1, hence the name.
| Terminal of Arrayjit.Ops.init_op
Extracts any available shape information from the initialization. E.g. for File_mapped fn, opens the file fn to check its length.
How to propagate shape updates and do the last update of Tensor.t.shape when finalizing the tensor. Axes are broadcast-expanded on a bottom-up update to fit the incoming shape.
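For concrete dimensions within one axis row, NumPy's broadcasting rule (which Pointwise_bin is described as following) aligns the axes from the right and lets a size-1 axis expand to the other size. The sketch below applies only that rule; broadcast_dims is a hypothetical name, and OCANNL's inference additionally handles labels, row variables and separate batch, input and output rows.

```ocaml
(* NumPy-style broadcasting of two dimension lists for one axis row:
   align from the right; sizes must be equal or one of them must be 1. *)
let broadcast_dims (a : int list) (b : int list) : int list =
  let rec go ra rb acc =
    match (ra, rb) with
    | [], [] -> acc
    | x :: ra', [] | [], x :: ra' -> go ra' [] (x :: acc)
    | x :: ra', y :: rb' ->
        let d =
          if x = y then x
          else if x = 1 then y
          else if y = 1 then x
          else invalid_arg "broadcast_dims: incompatible dimensions"
        in
        go ra' rb' (d :: acc)
  in
  (* Walk both lists from their rightmost axis. *)
  go (List.rev a) (List.rev b) []
```

For example, broadcasting [2; 1; 5] with [3; 1] expands both to [2; 3; 5].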
val sexp_of_logic : logic -> Sexplib0.Sexp.t
val logic_of_sexp : Sexplib0.Sexp.t -> logic
val hash_fold_update_id :
Ppx_hash_lib.Std.Hash.state ->
update_id ->
Ppx_hash_lib.Std.Hash.state
val hash_update_id : update_id -> Ppx_hash_lib.Std.Hash.hash_value
val sexp_of_update_id : update_id -> Sexplib0.Sexp.t
val update_id_of_sexp : Sexplib0.Sexp.t -> update_id
val get_update_id : Base.unit -> update_id
Data required for a shape inference update step. Ideally, an update should be performed at least twice, the second time after all the other relevant updates have been performed for the first time. In OCANNL, this is achieved by performing updates both as the tensors are constructed, and via lazy callbacks as the corresponding Arrayjit.Indexing dimensions and projections are first accessed.
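The "eager pass, then lazy second pass" strategy can be sketched with OCaml's Lazy module. Here an update step is modeled, hypothetically, as a unit -> unit closure, and run_twice is an illustrative name; OCANNL's update_step carries shape data rather than closures.

```ocaml
(* Hypothetical model: an update step is a closure refining shape state.
   The first pass runs eagerly (as tensors are constructed); the second
   pass is deferred until the returned lazy value is first forced. *)
let run_twice (updates : (unit -> unit) list) : unit Lazy.t =
  List.iter (fun u -> u ()) updates;
  lazy (List.iter (fun u -> u ()) updates)
```

Forcing the returned value a second time does not rerun the updates, since a lazy value is evaluated at most once.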
val sexp_of_update_step : update_step -> Sexplib0.Sexp.t
val update_step_of_sexp : Sexplib0.Sexp.t -> update_step
val to_dims : t -> Base.int Base.array
val propagate_shapes : update_step -> Base.unit
val derive_projections : update_step -> Arrayjit.Indexing.projections
Computes the indexing into subtensors given the shape information of a tensor. derive_projections should only be invoked when the shapes are fully inferred already!
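To give an intuition for what projections describe, the sketch below spells out the indexing for the einsum "ij;jk=>ik" as an explicit loop nest over the product space (i, j, k), with each tensor indexed by its own labels. It is an illustration only: matmul_via_projections is a hypothetical name, and it does not use Arrayjit.Indexing.projections. The loop bounds need concrete sizes, which is one reason projections are derived only once shapes are fully inferred.

```ocaml
(* Indexing for "ij;jk=>ik": iterate over the product space (i, j, k).
   LHS is indexed by (i, k), RHS1 by (i, j), RHS2 by (j, k). Assumes
   rectangular inputs with matching inner dimension j. *)
let matmul_via_projections (a : float array array) (b : float array array) =
  let ni = Array.length a and nj = Array.length b in
  let nk = if nj = 0 then 0 else Array.length b.(0) in
  let c = Array.make_matrix ni nk 0.0 in
  for i = 0 to ni - 1 do
    for j = 0 to nj - 1 do
      for k = 0 to nk - 1 do
        c.(i).(k) <- c.(i).(k) +. (a.(i).(j) *. b.(j).(k))
      done
    done
  done;
  c
```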
val backprop_ith_arg :
from_1:Base.int ->
Arrayjit.Indexing.projections ->
Arrayjit.Indexing.projections
val of_spec :
?deduced:deduce_within_shape ->
debug_name:Base.string ->
id:Base.int ->
Base.string ->
t
val default_display_indices : t -> Base.int Base.array
val to_labels : t -> Base.string Base.array
val axis_labels :
parsed_axis_labels ->
(Base.string, Base.int) Base.Either.t axis_map
val axis_labels_of_spec : Base.string -> parsed_axis_labels
val axis_map_to_dims_index : ?default:'a -> 'a axis_map -> 'a Base.array