package neural_nets_lib


Module Ocannl.Train

module NTDSL = Operation.NTDSL
module Utils = Arrayjit.Utils
module type Backend_type = Arrayjit.Backends.Backend
module Debug_runtime = Arrayjit.Utils.Debug_runtime
module CDSL : sig ... end
module IDX : sig ... end
val fresh_backend : ?backend_name:Base.String.t -> unit -> (module Arrayjit.Backends.Backend)

Reinitializes a backend selected via a global backend flag.
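For instance, the backend can be obtained as a first-class module at program startup. A minimal sketch, assuming Ocannl is opened so that Train and the related modules are in scope; the name "gccjit" is only an example, the available names depend on how Arrayjit was built:

  (* Select and reinitialize a backend, then unpack the first-class module. *)
  let () =
    let module Backend = (val Train.fresh_backend ~backend_name:"gccjit" ()) in
    (* [Backend] now satisfies Arrayjit.Backends.Backend and can be passed to
       the functions below that expect (module Backend_type ...). *)
    ignore (module Backend : Arrayjit.Backends.Backend)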

val is_param : Tensor.t -> Base.bool
val save_params : Tensor.t -> unit
val restore_params : Tensor.t -> unit
val set_on_host : Tn.memory_type -> Tn.t -> unit
val set_materialized : Tn.t -> unit
val set_hosted : Tn.t -> unit
val forward : ?disable_rootness_check:bool -> Tensor.t -> Asgns.t

Sets the tensor's value as "fully on host", returns the tensor's forward code with a label-derived comment.
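A minimal usage sketch, assuming t is a tensor built with the library's DSLs; the returned assignments are what a backend would subsequently compile:

  (* Obtain the forward code of [t]; as documented above, this also marks
     [t]'s value as fully on host. *)
  let forward_code_of t = Train.forward ~disable_rootness_check:false t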

type updaten = {
  loss : Tensor.t;
  label : Base.string;
  params : (Tensor.t, Tensor.comparator_witness) Base.Set.t;
  fwd_bprop : Asgns.t;
}
val grad_update : ?disable_rootness_check:bool -> ?setup_for_parallel:bool -> Tensor.t -> updaten

Returns the tensor's forward code, gradient-zeroing code, and backprop code, wrapped with label-derived comments. Sets the tensor's value as "fully on host". If setup_for_parallel is true (false by default), sets the parameters and their gradients as "non-local" (on-device).
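A sketch of preparing an update, assuming loss is a scalar loss tensor built from the model and the data (the name is a placeholder):

  let prepare_update loss =
    (* [update.fwd_bprop] bundles the forward pass, gradient zeroing and
       backprop; [update.params] is the set of parameters of the loss. *)
    let update = Train.grad_update ~setup_for_parallel:false loss in
    Printf.printf "updating %d parameters\n" (Base.Set.length update.params);
    update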

val sgd_one : learning_rate:Tensor.t -> ?momentum:Base.Float.t -> ?weight_decay:Base.float -> ?nesterov:bool -> Tensor.t -> Asgns.t

See: https://github.com/tinygrad/tinygrad/blob/master/tinygrad/nn/optim.py
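For reference, a scalar sketch of the conventional SGD-with-momentum rule from the linked optimizer; the tensor-level assignments emitted by sgd_one may differ in detail:

  (* One scalar SGD step: weight decay folded into the gradient, a momentum
     buffer [buf], and an optional Nesterov correction. Returns the updated
     parameter and the updated momentum buffer. *)
  let sgd_step ~learning_rate ?(momentum = 0.0) ?(weight_decay = 0.0)
      ?(nesterov = false) ~buf ~grad param =
    let g = grad +. (weight_decay *. param) in
    let buf = (momentum *. buf) +. g in
    let step = if nesterov then g +. (momentum *. buf) else buf in
    (param -. (learning_rate *. step), buf)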

val sgd_update : learning_rate:Tensor.t -> ?momentum:Base.Float.t -> ?weight_decay:Base.float -> ?nesterov:bool -> updaten -> Asgns.t
val sequential_loop : f:(unit -> Base.unit) -> (Idx.static_symbol * int Base.ref) list -> Base.unit

All and only bindings with associated ranges are iterated, with the binding's initial value lost. Bindings without ranges remain at their initial values.

Distributes iterated indices to workers in a round-robin fashion. All and only bindings with associated ranges are iterated, with the binding's initial value lost. Bindings without ranges remain at their initial values. sync is called after each round of calling all workers, and at the end if needed, with the number of workers called during the round.
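The distribution scheme can be pictured with a small self-contained sketch (illustrative only, not the module's implementation):

  (* Assign successive index values to [num_workers] workers round-robin and
     call [sync] after each full round, and once more at the end with the size
     of the final partial round. *)
  let round_robin_sketch ~num_workers ~work ~sync total =
    let in_round = ref 0 in
    for i = 0 to total - 1 do
      work ~worker:(i mod num_workers) ~index:i;
      incr in_round;
      if !in_round = num_workers then (
        sync num_workers;
        in_round := 0)
    done;
    if !in_round > 0 then sync !in_round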

val round_robin_dry_run : num_devices:int -> (Idx.static_symbol * int Base.ref) list -> dry_sync:(int -> Base.unit) -> Base.unit
val set_virtual : Tn.t -> unit
val every_non_literal_on_host : Tensor.t -> Base.unit
val all_host_to_device : (module Backend_type with type context = 'context) -> 'context -> Tensor.t -> Base.unit
val all_device_to_host : (module Backend_type with type context = 'context) -> 'context -> Tensor.t -> Base.unit
val sync_run : ?looping:(unit -> Base.unit) -> (module Backend_type with type context = 'context) -> 'context Arrayjit__Backends.routine -> Tensor.t -> Base.unit

Executes the jitted code and copies the arrays embedded in the given tensor to and from the host, synchronizing before copying to the host. If looping is provided, loops over bindings and executes the given function inside the loop after each run. All and only bindings with associated ranges are iterated, with the binding's initial value lost. Bindings without ranges remain at their initial values.
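A call-shape sketch, where routine must be a routine compiled by the given backend for code involving the tensor t (both names are placeholders):

  let run_once (type ctx)
      (module Backend : Train.Backend_type with type context = ctx) routine t =
    (* Copies hosted arrays to the device, runs [routine], synchronizes, and
       copies the results back to the host. *)
    Train.sync_run (module Backend) routine t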

module Lazy = Utils.Lazy
val collapse_merges : ('a, 'b Base.Option.t Base.Array.t) Base.Hashtbl.t -> 'b list Base.Array.t
val parallel_update : (module Backend_type with type context = 'context) -> grad_updates:'context Arrayjit__Backends.routine Base.array -> sgd_update:'context Arrayjit__Backends.routine -> post_sync:(num_synced_devices:Base.int -> Base.unit) -> updaten -> Base.unit -> Base.unit

Performs one optimization step, potentially in parallel (if grad_updates are compiled for different devices). All jitted code must have the same bindings. Iterates over bindings with ranges, calling one of grad_updates in a round-robin fashion, and performs the following synchronization each time all grad_updates have been called:

1. merges all gradients into the device of grad_updates.(0),
2. calls sgd_update,
3. copies all parameters from the grad_updates.(0) device to the other devices, if needed,
4. calls post_sync with the number of devices synced since the previous sync.

All and only bindings with associated ranges are iterated, with the binding's initial value lost. Bindings without ranges remain at their initial values.
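The call shape, assuming grad_update_routines (one compiled routine per device) and sgd_routine were produced earlier by the backend, and update is the updaten record returned by grad_update (all placeholder names):

  let train_step (type ctx)
      (module Backend : Train.Backend_type with type context = ctx)
      ~grad_update_routines ~sgd_routine update =
    let one_step =
      Train.parallel_update (module Backend)
        ~grad_updates:grad_update_routines ~sgd_update:sgd_routine
        ~post_sync:(fun ~num_synced_devices ->
          Printf.printf "synced %d devices\n%!" num_synced_devices)
        update
    in
    (* Each call performs one optimization step over all bindings with ranges. *)
    one_step ()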

val example_train_loop :
  ?disable_rootness_check:bool ->
  name:Base.String.t ->
  seed:Base.int ->
  batch_size:int ->
  init_lr:Base.float ->
  ?lr_schedule:(batch_n:Idx.static_symbol -> step_n:Idx.static_symbol -> Tensor.t) ->
  num_devices:int ->
  data_len:int ->
  epochs:int ->
  inputs:(b:int list -> Tensor.t) ->
  outputs:(b:int list -> Tensor.t) ->
  model:(Tensor.t -> Tensor.t) ->
  loss_fn:(output:Tensor.t -> expectation:Tensor.t -> Tensor.t) ->
  weight_decay:Base.float ->
  ?per_batch_callback:
    (at_batch:Base.int -> at_step:Base.int -> learning_rate:Base.float ->
     batch_loss:Base.float -> epoch_loss:float -> unit) ->
  ?per_epoch_callback:
    (at_step:Base.int -> at_epoch:int -> learning_rate:Base.float -> epoch_loss:float -> unit) ->
  (module Arrayjit.Backends.Backend) ->
  unit ->
  Tensor.t * Tensor.t * Tensor.t
  * (Base.float Base.array -> Base.float Base.array)
  * Base.float list * float list * Base.float list
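A schematic call matching the signature above; make_inputs, make_outputs, my_model and my_loss stand for user-supplied functions, and the names given to the returned tuple components are only informal guesses from the types:

  let () =
    let module Backend = (val Train.fresh_backend ()) in
    let inputs_t, outputs_t, model_result, infer, batch_losses, epoch_losses, lrs =
      Train.example_train_loop ~name:"demo" ~seed:1 ~batch_size:20 ~init_lr:0.1
        ~num_devices:1 ~data_len:320 ~epochs:10 ~inputs:make_inputs
        ~outputs:make_outputs ~model:my_model ~loss_fn:my_loss
        ~weight_decay:0.0002
        (module Backend) ()
    in
    ignore (inputs_t, outputs_t, model_result, infer, batch_losses, epoch_losses, lrs)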
val forward_and_forget : ?disable_rootness_check:Base.bool -> (module Backend_type with type context = 'context) -> 'context -> ?bindings:(Base.unit -> Base.unit) Idx.bindings -> Tensor.t -> Base.unit