module Tn = Arrayjit.Tnode
module Nd = Arrayjit.Ndarray
module NTDSL = Operation.NTDSL
module Asgns = Arrayjit.Assignments
module Idx = Arrayjit.Indexing
module Utils = Arrayjit.Utils
module Rand = Arrayjit.Rand.Lib
module BT = Arrayjit.Backend_utils.Types
module type Backend_type = Arrayjit.Backends.Backend
module Debug_runtime = Arrayjit.Utils.Debug_runtime
val _get_local_debug_runtime : unit -> (module Minidebug_runtime.Debug_runtime)
module CDSL : sig ... end
module IDX : sig ... end
val run : 'a BT.routine -> Base.unit
val get_params : Tensor.t -> (Tensor.t, Tensor.comparator_witness) Base.Set.t
val save_params : Tensor.t -> unit
val restore_params : Tensor.t -> unit
val set_on_host : Tn.memory_type -> Tn.t -> unit
val set_materialized : Tn.t -> unit
val set_hosted : Tn.t -> unit
val label_suffix : Base.String.t Base.List.t -> Base.String.t
Sets the tensor's value as "fully on host" and returns the tensor's forward code, wrapped with a label-derived comment.
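A minimal usage sketch of the memory-mode helpers above (not part of this interface; accessing the value node via a Tensor.value record field is an assumption, adjust to the actual record layout):

let prepare_for_inspection (t : Tensor.t) : unit =
  (* Request that the tensor's value node live on the host so it can be read back. *)
  set_hosted t.Tensor.value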
type updaten = {
loss : Tensor.t;
label : Base.string;
params : (Tensor.t, Tensor.comparator_witness) Base.Set.t;
fwd_bprop : Asgns.t;
}
Returns the tensor's forward code, gradient-zeroing code, and backprop code, wrapped with label-derived comments. Sets the tensor's value as "fully on host". If setup_for_parallel is true (false by default), sets the parameters and their gradients as "non-local" (on-device).
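An illustrative sketch, assuming the function documented above is exposed as grad_update (the name does not appear on this page) and takes the loss tensor:

let sketch_update (loss : Tensor.t) =
  (* Build the bundle of forward, gradient-zeroing and backprop code. *)
  let update = grad_update loss in
  Printf.printf "update %s: %d parameters\n" update.label
    (Base.Set.length update.params)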
val sgd_one :
learning_rate:Tensor.t ->
?momentum:Base.Float.t ->
?weight_decay:Base.float ->
?nesterov:bool ->
Tensor.t ->
Arrayjit.Assignments.t
See: https://github.com/tinygrad/tinygrad/blob/master/tinygrad/nn/optim.py
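A hedged sketch of generating the SGD assignments for a single parameter p; conceptually (following the linked tinygrad optimizer) the emitted code amounts to g <- g + weight_decay * p, b <- momentum * b + g, and p <- p - learning_rate * (g + momentum * b) in the Nesterov variant, though the exact generated assignments may differ:

let sgd_for_param ~(lr : Tensor.t) (p : Tensor.t) : Asgns.t =
  (* The learning rate is itself a tensor (cf. the lr_schedule parameter of example_train_loop below). *)
  sgd_one ~learning_rate:lr ~momentum:0.9 ~weight_decay:1e-4 ~nesterov:true p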
val sgd_update :
learning_rate:Tensor.t ->
?momentum:Base.Float.t ->
?weight_decay:Base.float ->
?nesterov:bool ->
updaten ->
Asgns.t
val sequential_loop :
f:(unit -> Base.unit) ->
(Idx.static_symbol * int Base.ref) list ->
Base.unit
All and only bindings with associated ranges are iterated, with the binding's initial value lost. Bindings without ranges remain at their initial values.
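A minimal sketch, assuming bindings was obtained from an already compiled routine (how to obtain it is not shown on this page):

let run_over_ranges (routine : _ BT.routine)
    (bindings : (Idx.static_symbol * int ref) list) : unit =
  (* Run the routine once for every combination of the ranged bindings. *)
  sequential_loop ~f:(fun () -> run routine) bindings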
val round_robin :
(unit -> 'a) Base.Array.t ->
Idx.lowered_bindings Base.Array.t ->
(Idx.static_symbol * Base.int Base.ref) list ->
sync:(Base.int -> Base.unit) ->
Base.unit
Distributes iterated indices to workers in a round-robin fashion. All and only bindings with associated ranges are iterated, with the binding's initial value lost. Bindings without ranges remain at their initial values. sync is called after each round of calling all workers, and at the end if needed, with the number of workers called during the round.
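An illustrative sketch, assuming works are per-device run thunks and the bindings arrays come from the corresponding compiled routines (their construction is not shown on this page):

let run_in_rounds (works : (unit -> unit) array)
    (per_worker_bindings : Idx.lowered_bindings array)
    (shared_bindings : (Idx.static_symbol * int ref) list) : unit =
  (* sync receives the number of workers called during the finished round. *)
  round_robin works per_worker_bindings shared_bindings ~sync:(fun called ->
      Printf.printf "round finished after %d worker calls\n" called)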
val round_robin_dry_run :
num_devices:int ->
(Idx.static_symbol * int Base.ref) list ->
dry_sync:(int -> Base.unit) ->
Base.unit
val set_virtual : Tn.t -> unit
val all_host_to_device :
(module Backend_type with type context = 'context) ->
'context ->
Tensor.t ->
Base.unit
val all_device_to_host :
(module Backend_type with type context = 'context) ->
'context ->
Tensor.t ->
Base.unit
val needs_prior_context : Tensor.t -> Tensor.tn Base.List.t
val sync_run :
?looping:(unit -> Base.unit) ->
(module Backend_type with type context = 'context) ->
'context Arrayjit.Backend_utils.Types.routine ->
Tensor.t ->
Base.unit
Executes the jitted code and copies arrays embedded in the given tensor from and to the host; synchronizes before copying to the host. If looping is provided, loops over bindings and executes the given function inside the loop after a run. All and only bindings with associated ranges are iterated, with the binding's initial value lost. Bindings without ranges remain at their initial values.
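A minimal sketch, assuming Backend satisfies Backend_type and routine was compiled for a context of that backend with the tensor t in scope:

let run_once (type ctx)
    (module Backend : Backend_type with type context = ctx)
    (routine : ctx BT.routine) (t : Tensor.t) : unit =
  (* Copies t's hosted arrays to the device, runs, then copies results back to the host. *)
  sync_run (module Backend) routine t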
module Lazy = Utils.Lazy
val parallel_update :
(module Backend_type with type context = 'context) ->
grad_updates:'context Arrayjit.Backend_utils.Types.routine Base.array ->
sgd_update:'context Arrayjit.Backend_utils.Types.routine ->
copy_to_merge:bool ->
post_sync:(num_synced_devices:Base.int -> Base.unit) ->
updaten ->
Base.unit ->
Base.unit
Performs one optimization step, potentially in parallel (if grad_updates are compiled for different devices). All jitted code must have the same bindings. Iterates over bindings with ranges, calling one of grad_updates in a round-robin fashion, and performs the following synchronization each time all grad_updates have been called:
1. merges all gradients into the device of grad_updates.(0),
2. calls sgd_update,
3. copies all parameters from the grad_updates.(0) device to the other devices, if needed,
4. calls post_sync with the number of devices synced since the previous sync.
All and only bindings with associated ranges are iterated, with the binding's initial value lost. Bindings without ranges remain at their initial values.
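A hedged sketch of driving one such step, assuming grad_updates were compiled one per device from the same bindings and sgd was compiled for the device of grad_updates.(0) (both obtained by code not shown on this page):

let one_step (type ctx)
    (module Backend : Backend_type with type context = ctx)
    ~(grad_updates : ctx BT.routine array) ~(sgd : ctx BT.routine)
    (update : updaten) : unit =
  parallel_update
    (module Backend)
    ~grad_updates ~sgd_update:sgd ~copy_to_merge:false
    ~post_sync:(fun ~num_synced_devices ->
      (* Called once per synchronization with the number of devices just synced. *)
      Printf.printf "synced %d devices\n" num_synced_devices)
    update ()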
val get_all_suggested_devices :
?max_num_devices:Base.int ->
(module Backend_type with type device = 'device) ->
'device Base.array
val example_train_loop :
?disable_rootness_check:Base.bool ->
seed:Base.int ->
batch_size:int ->
init_lr:Base.float ->
?lr_schedule:
(batch_n:Idx.static_symbol -> step_n:Idx.static_symbol -> Tensor.t) ->
?copy_to_merge:bool ->
?max_num_devices:Base.int ->
data_len:int ->
epochs:int ->
inputs:(b:int list -> Tensor.t) ->
outputs:(b:int list -> Tensor.t) ->
model:(Tensor.t -> Tensor.t) ->
loss_fn:(output:Tensor.t -> expectation:Tensor.t -> Tensor.t) ->
weight_decay:Base.float ->
?per_batch_callback:
(at_batch:Base.int ->
at_step:Base.int ->
learning_rate:Base.float ->
batch_loss:Base.float ->
epoch_loss:float ->
unit) ->
?per_epoch_callback:
(at_step:Base.int ->
at_epoch:int ->
learning_rate:Base.float ->
epoch_loss:float ->
unit) ->
?prior_contexts:'context Base.array ->
(module Backend_type with type context = 'context) ->
unit ->
Tensor.t
* Tensor.t
* Tensor.t
* (Base.float Base.array ->
Base.float Base.array)
* Base.float list
* float list
* Base.float list
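A hedged sketch of invoking the end-to-end loop; make_inputs, make_outputs, my_model and my_loss are placeholders for user code not shown on this page, and the names bound to the returned tuple components are guesses read off the result type:

let training_sketch (type ctx)
    (module Backend : Backend_type with type context = ctx)
    ~make_inputs ~make_outputs ~my_model ~my_loss =
  let _inputs, _outputs, _model_result, _infer, _batch_losses, epoch_losses, _lrs =
    example_train_loop ~seed:0 ~batch_size:32 ~init_lr:0.01 ~data_len:1024
      ~epochs:10 ~inputs:make_inputs ~outputs:make_outputs ~model:my_model
      ~loss_fn:my_loss ~weight_decay:1e-4
      (module Backend)
      ()
  in
  (* Report the loss recorded after each epoch. *)
  List.iter (Printf.printf "epoch loss: %f\n") epoch_losses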
val forward_and_ctx :
?disable_rootness_check:Base.bool ->
(module Backend_type with type context = 'context) ->
'context ->
?bindings:(Base.unit -> Base.unit) Idx.bindings ->
Tensor.t ->
'context
val forward_and_forget :
?disable_rootness_check:Base.bool ->
(module Backend_type with type context = 'a) ->
'a ->
?bindings:(Base.unit -> Base.unit) Idx.bindings ->
Tensor.t ->
Base.unit