Module Ocannl.Train
Reinitializes a backend selected via a global backend flag.

Sets the tensor's value as "fully on host", returns the tensor's forward code with a label-derived comment.
type updaten = {
  loss : Tensor.t;
  label : Base.string;
  params : (Tensor.t, Tensor.comparator_witness) Base.Set.t;
  fwd_bprop : Asgns.t;
}

val grad_update :
?disable_rootness_check:bool ->
?setup_for_parallel:bool ->
Tensor.t ->
updaten

Returns the tensor's forward, gradient-zeroing, and backprop code, wrapped with label-derived comments. Sets the tensor's value as "fully on host". If setup_for_parallel is true (false by default), sets the parameters and their gradients as "non-local" (on-device).
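A minimal usage sketch, assuming this module is accessed as Train and that loss is a scalar loss tensor built elsewhere (not shown):

let make_update (loss : Tensor.t) : Train.updaten =
  (* The result's fwd_bprop bundles the forward, gradient-zeroing and
     backprop assignments, ready for a backend to compile; params collects
     the parameters participating in the loss. *)
  Train.grad_update ~setup_for_parallel:false loss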
val sgd_one :
learning_rate:Tensor.t ->
?momentum:Base.Float.t ->
?weight_decay:Base.float ->
?nesterov:bool ->
Tensor.t ->
Asgns.t

See: https://github.com/tinygrad/tinygrad/blob/master/tinygrad/nn/optim.py
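As a semantic reference only, here is the update rule the linked tinygrad optimizer implements, sketched over plain float arrays. The helper sgd_step below is hypothetical; the real sgd_one emits Asgns.t assignment code over tensors rather than mutating arrays:

let sgd_step ~lr ?(momentum = 0.0) ?(weight_decay = 0.0) ?(nesterov = false)
    ~(params : float array) ~(grads : float array)
    ~(velocities : float array) () =
  Array.iteri
    (fun i p ->
      (* Weight decay folds the L2 penalty directly into the gradient. *)
      let g = grads.(i) +. (weight_decay *. p) in
      (* Momentum accumulates a decaying sum of past gradients. *)
      let v = (momentum *. velocities.(i)) +. g in
      velocities.(i) <- v;
      (* Nesterov momentum takes the step at the looked-ahead point. *)
      let step = if nesterov then g +. (momentum *. v) else v in
      params.(i) <- p -. (lr *. step))
    params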
val sgd_update :
learning_rate:Tensor.t ->
?momentum:Base.Float.t ->
?weight_decay:Base.float ->
?nesterov:bool ->
updaten ->
Asgns.t

val sequential_loop :
f:(unit -> Base.unit) ->
(Idx.static_symbol * int Base.ref) list ->
Base.unit

All and only bindings with associated ranges are iterated, with the binding's initial value lost. Bindings without ranges remain at their initial values.
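A hedged sketch of driving a compiled routine with sequential_loop; bindings and run are hypothetical stand-ins for a routine's symbol-to-index refs and its execution thunk:

let loop_all (bindings : (Idx.static_symbol * int ref) list)
    (run : unit -> unit) =
  (* Presumably each ranged binding's ref is set to the current index
     before every call to f. *)
  Train.sequential_loop ~f:run bindings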
val round_robin :
(unit -> 'a) Base.Array.t ->
Idx.jitted_bindings Base.Array.t ->
(Idx.static_symbol * Base.int Base.ref) list ->
sync:(Base.int -> Base.unit) ->
Base.unit

Distributes iterated indices to workers in a round-robin fashion. All and only bindings with associated ranges are iterated, with the binding's initial value lost. Bindings without ranges remain at their initial values. sync is called after each round of calling all workers, and at the end if needed, with the number of workers called during the round.
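A hedged sketch of dispatching per-device work with round_robin; workers, jitted_bindings and iterated are hypothetical values that would come from compiling one routine per device:

let dispatch ~(workers : (unit -> unit) array)
    ~(jitted_bindings : Idx.jitted_bindings array)
    ~(iterated : (Idx.static_symbol * int ref) list) =
  Train.round_robin workers jitted_bindings iterated
    ~sync:(fun num_called ->
      (* Called after each full round, with the number of workers used. *)
      Printf.printf "round done: %d workers called\n%!" num_called)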
val round_robin_dry_run :
num_devices:int ->
(Idx.static_symbol * int Base.ref) list ->
dry_sync:(int -> Base.unit) ->
Base.unit

val all_host_to_device :
(module Backend_type with type context = 'context) ->
'context ->
Tensor.t ->
Base.unit

val all_device_to_host :
(module Backend_type with type context = 'context) ->
'context ->
Tensor.t ->
Base.unit

val sync_run :
?looping:(unit -> Base.unit) ->
(module Backend_type with type context = 'context) ->
'context Arrayjit__Backends.routine ->
Tensor.t ->
Base.unit

Executes the jitted code and copies the arrays embedded in the given tensor from and to the host; synchronizes before copying to host. If looping is provided, loops over bindings and executes the given function inside the loop after a run. All and only bindings with associated ranges are iterated, with the binding's initial value lost. Bindings without ranges remain at their initial values.
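A hedged sketch combining the copying helpers with sync_run; Backend_type, the context ctx and the routine are assumed to come from the Arrayjit backends API and are not constructed here:

let run_step (type ctx)
    (module Backend : Backend_type with type context = ctx)
    (ctx : ctx) (routine : ctx Arrayjit__Backends.routine)
    (loss : Tensor.t) =
  (* Make sure host-side data is on the device before running. *)
  Train.all_host_to_device (module Backend) ctx loss;
  (* Run the routine, then synchronize and copy results back to the host. *)
  Train.sync_run (module Backend) routine loss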
val collapse_merges :
('a, 'b Base.Option.t Base.Array.t) Base.Hashtbl.t ->
'b list Base.Array.t

val parallel_update :
(module Backend_type with type context = 'context) ->
grad_updates:'context Arrayjit__Backends.routine Base.array ->
sgd_update:'context Arrayjit__Backends.routine ->
post_sync:(num_synced_devices:Base.int -> Base.unit) ->
updaten ->
Base.unit ->
Base.unit

Performs one optimization step, potentially in parallel (if grad_updates are compiled for different devices). All jitted code must have the same bindings. Iterates over bindings with ranges, calling one of grad_updates in a round-robin fashion, and performs the following synchronization each time all grad_updates have been called:

1. merges all gradients into the device of grad_updates.(0),
2. calls sgd_update,
3. copies all parameters from the grad_updates.(0) device to the other devices, if needed,
4. calls post_sync with the number of devices synced since the previous sync.
All and only bindings with associated ranges are iterated, with the binding's initial value lost. Bindings without ranges remain at their initial values.
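A hedged sketch: grad_routines are hypothetical per-device compilations of the updaten's fwd_bprop, and sgd_routine a compilation of the code produced by sgd_update; partially applying parallel_update yields a reusable step thunk:

let make_step (type ctx)
    (module Backend : Backend_type with type context = ctx)
    ~(grad_routines : ctx Arrayjit__Backends.routine array)
    ~(sgd_routine : ctx Arrayjit__Backends.routine)
    (update : Train.updaten) : unit -> unit =
  Train.parallel_update (module Backend)
    ~grad_updates:grad_routines ~sgd_update:sgd_routine
    ~post_sync:(fun ~num_synced_devices ->
      Printf.printf "synced %d device(s)\n%!" num_synced_devices)
    update

Calling the resulting thunk performs one optimization step.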
val example_train_loop :
?disable_rootness_check:bool ->
name:Base.String.t ->
seed:Base.int ->
batch_size:int ->
init_lr:Base.float ->
?lr_schedule:
(batch_n:Idx.static_symbol -> step_n:Idx.static_symbol -> Tensor.t) ->
num_devices:int ->
data_len:int ->
epochs:int ->
inputs:(b:int list -> Tensor.t) ->
outputs:(b:int list -> Tensor.t) ->
model:(Tensor.t -> Tensor.t) ->
loss_fn:(output:Tensor.t -> expectation:Tensor.t -> Tensor.t) ->
weight_decay:Base.float ->
?per_batch_callback:
(at_batch:Base.int ->
at_step:Base.int ->
learning_rate:Base.float ->
batch_loss:Base.float ->
epoch_loss:float ->
unit) ->
?per_epoch_callback:
(at_step:Base.int ->
at_epoch:int ->
learning_rate:Base.float ->
epoch_loss:float ->
unit) ->
(module Arrayjit.Backends.Backend) ->
unit ->
Tensor.t
* Tensor.t
* Tensor.t
* (Base.float Base.array ->
Base.float Base.array)
* Base.float list
* float list
* Base.float list

val forward_and_forget :
?disable_rootness_check:Base.bool ->
(module Backend_type with type context = 'context) ->
'context ->
?bindings:(Base.unit -> Base.unit) Idx.bindings ->
Tensor.t ->
Base.unit
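An end-to-end hedged sketch of example_train_loop; Backend, get_inputs, get_outputs, mlp and mse are hypothetical stand-ins, and the tuple component names below merely follow the order of the return type:

let () =
  let model_out, inputs_t, outputs_t, infer,
      batch_losses, epoch_losses, learning_rates =
    Train.example_train_loop ~name:"demo" ~seed:0 ~batch_size:20
      ~init_lr:0.1 ~num_devices:1 ~data_len:320 ~epochs:10
      ~inputs:get_inputs ~outputs:get_outputs ~model:mlp
      ~loss_fn:(fun ~output ~expectation -> mse ~output ~expectation)
      ~weight_decay:1e-4
      (module Backend) ()
  in
  ignore (model_out, inputs_t, outputs_t,
          batch_losses, epoch_losses, learning_rates);
  (* Per its type, infer presumably maps a host-side input row to the
     model's predictions. *)
  ignore (infer [| 0.5; -0.5 |])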