package neural_nets_lib

  1. Overview
  2. Docs
module Tn = Arrayjit.Tnode
module Nd = Arrayjit.Ndarray
module NTDSL = Operation.NTDSL
module Asgns = Arrayjit.Assignments
module Idx = Arrayjit.Indexing
module Utils = Arrayjit.Utils
module Rand = Arrayjit.Rand.Lib
module type Backend_type = Arrayjit.Backends.Backend
module Debug_runtime = Arrayjit.Utils.Debug_runtime
val _get_local_debug_runtime : unit -> (module Minidebug_runtime.Debug_runtime)
module CDSL : sig ... end
module IDX : sig ... end
val run : 'a BT.routine -> Base.unit
val is_param : Tensor.t -> Base.bool
val save_params : Tensor.t -> unit
val restore_params : Tensor.t -> unit
val set_on_host : Tn.memory_type -> Tn.t -> unit
val set_materialized : Tn.t -> unit
val set_hosted : Tn.t -> unit
val forward : ?disable_rootness_check:bool -> Tensor.t -> Asgns.t

Sets the tensor's value as "fully on host", and returns the tensor's forward code with a label-derived comment.

type updaten = {
  loss : Tensor.t;
  label : Base.string;
  params : (Tensor.t, Tensor.comparator_witness) Base.Set.t;
  fwd_bprop : Asgns.t;
}
val grad_update : ?disable_rootness_check:bool -> ?setup_for_parallel:bool -> Tensor.t -> updaten

Returns the tensor's forward, zeroing gradients, and backprop code wrapped with label-derived comments. Sets the tensor's value as "fully on host". If setup_for_parallel is true (false by default), sets the parameters and their gradients as "non-local" (on-device).

val sgd_one : learning_rate:Tensor.t -> ?momentum:Base.Float.t -> ?weight_decay:Base.float -> ?nesterov:bool -> Tensor.t -> Arrayjit.Assignments.t

See: https://github.com/tinygrad/tinygrad/blob/master/tinygrad/nn/optim.py

val sgd_update : learning_rate:Tensor.t -> ?momentum:Base.Float.t -> ?weight_decay:Base.float -> ?nesterov:bool -> updaten -> Asgns.t
val sequential_loop : f:(unit -> Base.unit) -> (Idx.static_symbol * int Base.ref) list -> Base.unit

All and only bindings with associated ranges are iterated, with the binding's initial value lost. Bindings without ranges remain at their initial values.

val round_robin : (unit -> 'a) Base.Array.t -> Idx.lowered_bindings Base.Array.t -> (Idx.static_symbol * Base.int Base.ref) list -> sync:(Base.int -> Base.unit) -> Base.unit

Distributes iterated indices to workers in a round-robin fashion. All and only bindings with associated ranges are iterated, with the binding's initial value lost. Bindings without ranges remain at their initial values. sync is called after each round of calling all workers, and at the end if needed, with the number of workers called during the round.

val round_robin_dry_run : num_devices:int -> (Idx.static_symbol * int Base.ref) list -> dry_sync:(int -> Base.unit) -> Base.unit
val set_virtual : Tn.t -> unit
val every_non_literal_on_host : Tensor.t -> Base.unit
val all_host_to_device : (module Backend_type with type context = 'context) -> 'context -> Tensor.t -> Base.unit
val all_device_to_host : (module Backend_type with type context = 'context) -> 'context -> Tensor.t -> Base.unit
val needs_prior_context : Tensor.t -> Tensor.tn Base.List.t
val sync_run : ?looping:(unit -> Base.unit) -> (module Backend_type with type context = 'context) -> 'context Arrayjit.Backend_utils.Types.routine -> Tensor.t -> Base.unit

Executes the jitted code and copies arrays embedded in the given tensor from and to host, synchronizing before copying to host. If looping is provided, loops over bindings and executes the given function inside the loop after a run. All and only bindings with associated ranges are iterated, with the binding's initial value lost. Bindings without ranges remain at their initial values.

module Lazy = Utils.Lazy
val parallel_update : (module Backend_type with type context = 'context) -> grad_updates:'context Arrayjit.Backend_utils.Types.routine Base.array -> sgd_update:'context Arrayjit.Backend_utils.Types.routine -> copy_to_merge:bool -> post_sync:(num_synced_devices:Base.int -> Base.unit) -> updaten -> Base.unit -> Base.unit

Performs one optimization step, potentially in parallel (if grad_updates are compiled for different devices). All jitted code must have the same bindings. Iterates over bindings with ranges, calling one of grad_updates in a round-robin fashion, and performs the following synchronization each time all grad_updates have been called:

1. merges all gradients into the device of grad_updates.(0),
2. calls sgd_update,
3. copies all parameters from the grad_updates.(0) device to the other devices, if needed,
4. calls post_sync with the number of devices synced since the previous sync.

All and only bindings with associated ranges are iterated, with the binding's initial value lost. Bindings without ranges remain at their initial values.

val get_all_suggested_devices : ?max_num_devices:Base.int -> (module Backend_type with type device = 'device) -> 'device Base.array
val example_train_loop : ?disable_rootness_check:Base.bool -> seed:Base.int -> batch_size:int -> init_lr:Base.float -> ?lr_schedule: (batch_n:Idx.static_symbol -> step_n:Idx.static_symbol -> Tensor.t) -> ?copy_to_merge:bool -> ?max_num_devices:Base.int -> data_len:int -> epochs:int -> inputs:(b:int list -> Tensor.t) -> outputs:(b:int list -> Tensor.t) -> model:(Tensor.t -> Tensor.t) -> loss_fn:(output:Tensor.t -> expectation:Tensor.t -> Tensor.t) -> weight_decay:Base.float -> ?per_batch_callback: (at_batch:Base.int -> at_step:Base.int -> learning_rate:Base.float -> batch_loss:Base.float -> epoch_loss:float -> unit) -> ?per_epoch_callback: (at_step:Base.int -> at_epoch:int -> learning_rate:Base.float -> epoch_loss:float -> unit) -> ?prior_contexts:'context Base.array -> (module Backend_type with type context = 'context) -> unit -> Tensor.t * Tensor.t * Tensor.t * (Base.float Base.array -> Base.float Base.array) * Base.float list * float list * Base.float list
val forward_and_ctx : ?disable_rootness_check:Base.bool -> (module Backend_type with type context = 'context) -> 'context -> ?bindings:(Base.unit -> Base.unit) Idx.bindings -> Tensor.t -> 'context
val forward_and_forget : ?disable_rootness_check:Base.bool -> (module Backend_type with type context = 'a) -> 'a -> ?bindings:(Base.unit -> Base.unit) Idx.bindings -> Tensor.t -> Base.unit
OCaml

Innovation. Community. Security.