Legend:
Library
Module
Module type
Parameter
Class
Class type
Library
Module
Module type
Parameter
Class
Class type
module Tn = Arrayjit.Tnode
module Nd = Arrayjit.Ndarray
module NTDSL = Operation.NTDSL
module Asgns = Arrayjit.Assignments
module Idx = Arrayjit.Indexing
module Utils = Arrayjit.Utils
module Rand = Arrayjit.Rand.Lib
module type Backend_type = Arrayjit.Backends.Backend
module Debug_runtime = Arrayjit.Utils.Debug_runtime
module CDSL : sig ... end
module IDX : sig ... end
val debug_rt : (module Minidebug_runtime.Debug_runtime)
val run : 'a Arrayjit.Backends.routine -> Base.unit
val fresh_backend :
?backend_name:Base.String.t ->
unit ->
(module Arrayjit.Backends.Backend)
Reinitializes a backend selected via a global backend flag.
val is_param : Tensor.t -> Base.bool
val get_params : Tensor.t -> (Tensor.t, Tensor.comparator_witness) Base.Set.t
val save_params : Tensor.t -> unit
val restore_params : Tensor.t -> unit
val set_on_host : Tn.memory_type -> Tn.t -> unit
val set_materialized : Tn.t -> unit
val set_hosted : Tn.t -> unit
Sets the tensor's value as "fully on host" and returns the tensor's forward code with a label-derived comment.
type updaten = {
loss : Tensor.t;
label : Base.string;
params : (Tensor.t, Tensor.comparator_witness) Base.Set.t;
fwd_bprop : Asgns.t;
}
Returns the tensor's forward, zeroing gradients, and backprop code wrapped with label-derived comments. Sets the tensor's value as "fully on host". If `setup_for_parallel` is true (false by default), sets the parameters and their gradients as "non-local" (on-device).
val sgd_one :
learning_rate:Tensor.t ->
?momentum:Base.Float.t ->
?weight_decay:Base.float ->
?nesterov:bool ->
Tensor.t ->
Asgns.t
See: https://github.com/tinygrad/tinygrad/blob/master/tinygrad/nn/optim.py
val sequential_loop :
f:(unit -> Base.unit) ->
(Idx.static_symbol * int Base.ref) list ->
Base.unit
All and only bindings with associated ranges are iterated, with the binding's initial value lost. Bindings without ranges remain at their initial values.
val round_robin :
(unit -> 'a) Base.Array.t ->
Idx.jitted_bindings Base.Array.t ->
(Idx.static_symbol * Base.int Base.ref) list ->
sync:(Base.int -> Base.unit) ->
Base.unit
Distributes iterated indices to workers in a round-robin fashion. All and only bindings with associated ranges are iterated, with the binding's initial value lost. Bindings without ranges remain at their initial values. `sync` is called after each round of calling all workers, and at the end if needed, with the number of workers called during the round.
val round_robin_dry_run :
num_devices:Base__Int.t ->
(Idx.static_symbol * int Base.ref) list ->
dry_sync:(Base__Int.t -> Base.unit) ->
Base.unit
val set_virtual : Tn.t -> unit
val every_non_literal_on_host : Tensor.t -> Base.unit
val all_host_to_device :
(module Backend_type with type context = 'context) ->
'context ->
Tensor.t ->
Base.unit
val all_device_to_host :
(module Backend_type with type context = 'context) ->
'context ->
Tensor.t ->
Base.unit
val sync_run :
?looping:(unit -> Base.unit) ->
(module Backend_type with type context = 'context) ->
'context Arrayjit__Backends.routine ->
Tensor.t ->
Base.unit
Executes the jitted code and copies arrays embedded in the given tensor from and to host, synchronizing before copying to host. If `looping` is provided, loops over bindings and executes the given function inside the loop after a run. All and only bindings with associated ranges are iterated, with the binding's initial value lost. Bindings without ranges remain at their initial values.
module Lazy = Utils.Lazy
val parallel_update :
(module Backend_type with type context = 'context) ->
grad_updates:'context Arrayjit__Backends.routine Base.array ->
sgd_update:'context Arrayjit__Backends.routine ->
post_sync:(num_synced_devices:Base.int -> Base.unit) ->
updaten ->
Base.unit ->
Base.unit
Performs one optimization step, potentially in parallel (if `grad_updates` are compiled for different devices). All jitted code must have the same bindings. Iterates over bindings with ranges, calling one of `grad_updates` in a round-robin fashion, and performs the following synchronization each time all `grad_updates` have been called:
1. merges all gradients into the device of `grad_updates.(0)`;
2. calls `sgd_update`;
3. copies all parameters from the `grad_updates.(0)` device to the other devices, if needed;
4. calls `post_sync` with the number of devices synced since the previous sync.
All and only bindings with associated ranges are iterated, with the binding's initial value lost. Bindings without ranges remain at their initial values.
val example_train_loop :
?disable_rootness_check:bool ->
name:Base.String.t ->
seed:Base.int ->
batch_size:Base__Int.t ->
init_lr:Base.float ->
?lr_schedule:
(batch_n:Idx.static_symbol -> step_n:Idx.static_symbol -> Tensor.t) ->
num_devices:Base__Int.t ->
data_len:Base__Int.t ->
epochs:Base__Int.t ->
inputs:(b:Base__Int.t list -> Tensor.t) ->
outputs:(b:Base__Int.t list -> Tensor.t) ->
model:(Tensor.t -> Tensor.t) ->
loss_fn:(output:Tensor.t -> expectation:Tensor.t -> Tensor.t) ->
weight_decay:Base.float ->
?per_batch_callback:
(at_batch:Base.int ->
at_step:Base.int ->
learning_rate:Base.float ->
batch_loss:Base.float ->
epoch_loss:float ->
unit) ->
?per_epoch_callback:
(at_step:Base.int ->
at_epoch:int ->
learning_rate:Base.float ->
epoch_loss:float ->
unit) ->
(module Arrayjit.Backends.Backend) ->
unit ->
Tensor.t
* Tensor.t
* Tensor.t
* (Base.float Base.array ->
Base.float Base.array)
* Base.float list
* float list
* Base.float list
val forward_and_forget :
?disable_rootness_check:Base.bool ->
(module Backend_type with type context = 'context) ->
'context ->
?bindings:(Base.unit -> Base.unit) Idx.bindings ->
Tensor.t ->
Base.unit