Module Kaun.Optimizer

Optax-inspired optimisation algorithms with explicit checkpoint support.

type state

Optimiser state: an opaque container storing algorithm-specific payloads.

type algorithm

Optimisation algorithm acting on parameter trees.

val name : algorithm -> string

Human-readable name for the algorithm.

val init : algorithm -> Ptree.t -> state

Initialise optimiser state for the given parameters.

Apply the algorithm to gradients, returning parameter updates and the next optimiser state.
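The init/update contract can be pictured with a small self-contained model. This is a minimal sketch under assumptions, not Kaun's implementation: flat float arrays stand in for Ptree.t, the hypothetical spec record and T wrapper play the role of algorithm, and the state type is hidden behind an existential, much like the opaque state type. The exact signature of Kaun's update function is not shown on this page; the hypothetical chain2 and chain helpers mirror the documented chain combinator by threading a pair of sub-states.

```ocaml
(* Hypothetical model of an Optax-style gradient transformation.
   Flat float arrays stand in for Kaun's Ptree.t; the state type 's is
   hidden behind an existential, like the opaque [state] type. *)
type 's spec = {
  init : float array -> 's;                        (* params -> state *)
  update : float array -> 's -> float array * 's;  (* grads -> state -> (updates, next state) *)
}

type transform = T : 's spec -> transform

(* Stateless building blocks: the state is just unit. *)
let identity = T { init = (fun _ -> ()); update = (fun g s -> (g, s)) }

let scale k =
  T { init = (fun _ -> ()); update = (fun g s -> (Array.map (fun x -> k *. x) g, s)) }

(* Sequential composition: the second stage transforms the first
   stage's output, and the combined state is the pair of stage states. *)
let chain2 a b =
  match a, b with
  | T a, T b ->
    T {
      init = (fun p -> (a.init p, b.init p));
      update =
        (fun g (sa, sb) ->
          let g, sa = a.update g sa in
          let g, sb = b.update g sb in
          (g, (sa, sb)));
    }

let chain ts = List.fold_left chain2 identity ts

(* Initialise state from the params, then apply a single update. *)
let apply_once tr params grads =
  match tr with
  | T t ->
    let state = t.init params in
    let updates, _next_state = t.update grads state in
    updates
```

Here chain [scale 2.0; scale (-0.5)] maps gradients [|1.; 2.|] to updates [|-1.; -2.|].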

Serialise an optimiser state into a checkpoint snapshot.

val restore : algorithm -> Checkpoint.Snapshot.t -> (state, string) result

Restore optimiser state for the given algorithm from a checkpoint.

val step_count : state -> int option

Extract the step counter from an optimiser state, if the algorithm keeps one (e.g. Adam); returns None otherwise.

Algorithm building blocks

val identity : unit -> algorithm
val scale : float -> algorithm
val scale_by_neg_one : unit -> algorithm
val add_decayed_weights : float -> algorithm
val clip_by_global_norm : float -> algorithm
val clip : float -> algorithm
val trace : decay:float -> ?nesterov:bool -> unit -> algorithm
val scale_by_rms : ?decay:float -> ?eps:float -> unit -> algorithm
val scale_by_adam : ?b1:float -> ?b2:float -> ?eps:float -> unit -> algorithm
val scale_by_belief : ?b1:float -> ?b2:float -> ?eps:float -> unit -> algorithm
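For intuition about scale_by_adam, here is a minimal sketch of the Adam moment update as Optax defines it: bias-corrected first and second moments, with eps added outside the square root. The adam_state record, adam_init and the flat-array representation are hypothetical, and whether Kaun matches Optax in every detail is an assumption. The output is a direction, not a step; in Optax-style compositions it is typically followed by a learning-rate scale and scale_by_neg_one.

```ocaml
(* Hypothetical Adam state over a flat float array (not Kaun's Ptree.t). *)
type adam_state = { count : int; mu : float array; nu : float array }

let adam_init n = { count = 0; mu = Array.make n 0.0; nu = Array.make n 0.0 }

(* One scale_by_adam step in the Optax formulation (assumed to match Kaun). *)
let scale_by_adam ?(b1 = 0.9) ?(b2 = 0.999) ?(eps = 1e-8) grads st =
  let count = st.count + 1 in
  (* Exponential moving averages of the gradient and squared gradient. *)
  let mu = Array.mapi (fun i g -> (b1 *. st.mu.(i)) +. ((1. -. b1) *. g)) grads in
  let nu = Array.mapi (fun i g -> (b2 *. st.nu.(i)) +. ((1. -. b2) *. g *. g)) grads in
  (* Bias corrections compensate for the zero initialisation. *)
  let bc1 = 1. -. (b1 ** float_of_int count) in
  let bc2 = 1. -. (b2 ** float_of_int count) in
  let updates =
    Array.mapi (fun i m -> (m /. bc1) /. (sqrt (nu.(i) /. bc2) +. eps)) mu
  in
  (updates, { count; mu; nu })
```

On the first step the bias corrections cancel the (1 - b) factors, so each update is roughly the sign of its gradient.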

Learning rate schedules

module Schedule : sig ... end
val scale_by_schedule : Schedule.t -> algorithm
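Schedule.t is abstract on this page. A minimal sketch, assuming a schedule behaves like a function from step count to learning rate: the schedule type and the constant, linear_warmup and exponential_decay constructors below are hypothetical illustrations, not the actual contents of Schedule. The hypothetical scale_by_schedule helper multiplies updates by the schedule's value at the current step.

```ocaml
(* Hypothetical representation: a schedule maps a step count to a rate. *)
type schedule = int -> float

let constant lr : schedule = fun _ -> lr

(* Ramp linearly from 0 to [peak] over [warmup_steps], then hold. *)
let linear_warmup ~peak ~warmup_steps : schedule =
 fun step ->
  if step >= warmup_steps then peak
  else peak *. float_of_int step /. float_of_int warmup_steps

(* init * decay_rate^(step / transition_steps), as in Optax's exponential_decay. *)
let exponential_decay ~init ~decay_rate ~transition_steps : schedule =
 fun step -> init *. (decay_rate ** (float_of_int step /. float_of_int transition_steps))

(* Scale every update by the schedule's value at [step]. *)
let scale_by_schedule (s : schedule) step updates =
  let lr = s step in
  Array.map (fun u -> lr *. u) updates
```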

Composition helpers

val chain : algorithm list -> algorithm
type label_tree =
  | Label_tensor of int
  | Label_list of label_tree list
  | Label_record of (string * label_tree) list
type mask_tree =
  | Mask_tensor of bool
  | Mask_list of mask_tree list
  | Mask_record of (string * mask_tree) list
val multi_transform : transforms:algorithm list -> labels:(Ptree.t -> label_tree) -> algorithm

Applies different transforms to different parts of the parameter tree, according to labels computed from the parameters.
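A minimal sketch of how label-driven dispatch could work, under two labelled assumptions: the integer in Label_tensor selects the transform at that position in transforms (Optax's multi_transform keys transforms by label rather than by index), and transforms are stateless functions on flat arrays. The tree type, its constructors and the multi_transform function below are hypothetical stand-ins for Ptree.t and Kaun's combinator.

```ocaml
(* Hypothetical simplified parameter tree with flat float-array leaves. *)
type tree =
  | Tensor of float array
  | Items of tree list
  | Fields of (string * tree) list

(* Same shape as the documented label_tree. *)
type label_tree =
  | Label_tensor of int
  | Label_list of label_tree list
  | Label_record of (string * label_tree) list

(* Walk labels and gradients in lockstep; a leaf labelled [i] is
   transformed by [transforms.(i)] (assumed index semantics). *)
let rec multi_transform (transforms : (float array -> float array) array) labels grads =
  match labels, grads with
  | Label_tensor i, Tensor g -> Tensor (transforms.(i) g)
  | Label_list ls, Items gs -> Items (List.map2 (multi_transform transforms) ls gs)
  | Label_record ls, Fields gs ->
    Fields
      (List.map2
         (fun (k, l) (k', g) ->
           if k <> k' then invalid_arg "multi_transform: key mismatch";
           (k, multi_transform transforms l g))
         ls gs)
  | _ -> invalid_arg "multi_transform: label tree does not match gradient tree"
```

A typical use is giving weights and biases different treatment, for example applying weight decay only to leaves labelled for it.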

val masked : mask:(Ptree.t -> mask_tree) -> inner:algorithm -> algorithm

Masks gradients/updates using a mask tree computed from the parameters by the mask function.
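A minimal sketch of masking, assuming Optax semantics: leaves whose mask is true go through inner, and the rest pass through unchanged. Whether Kaun passes masked-out leaves through or zeroes them is not stated here, so treat that as an assumption. The tree type is a hypothetical stand-in for Ptree.t, inner is a stateless array function, and the mask is passed already computed rather than as a function of the parameters.

```ocaml
(* Hypothetical simplified parameter tree with flat float-array leaves. *)
type tree =
  | Tensor of float array
  | Items of tree list
  | Fields of (string * tree) list

(* Same shape as the documented mask_tree. *)
type mask_tree =
  | Mask_tensor of bool
  | Mask_list of mask_tree list
  | Mask_record of (string * mask_tree) list

(* Leaves with a true mask are transformed by [inner]; the others are
   returned unchanged (Optax-style pass-through, assumed here). *)
let rec masked ~mask ~inner grads =
  match mask, grads with
  | Mask_tensor true, Tensor g -> Tensor (inner g)
  | Mask_tensor false, Tensor g -> Tensor g
  | Mask_list ms, Items gs ->
    Items (List.map2 (fun m g -> masked ~mask:m ~inner g) ms gs)
  | Mask_record ms, Fields gs ->
    Fields (List.map2 (fun (k, m) (_, g) -> (k, masked ~mask:m ~inner g)) ms gs)
  | _ -> invalid_arg "masked: mask tree does not match gradient tree"
```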

Utility functions

val apply_updates : Ptree.t -> Ptree.t -> Ptree.t
val apply_updates_inplace : Ptree.t -> Ptree.t -> unit
val global_norm : Ptree.t -> float
val set_to_zero : Ptree.t -> Ptree.t
val multi_steps : every:int -> algorithm -> algorithm
val with_gradient_stats : ?prefix:string -> algorithm -> algorithm
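A minimal sketch of global_norm, apply_updates and the clip_by_global_norm building block, assuming Optax conventions: the global norm is the L2 norm over every element of every leaf taken together, clipping rescales all leaves by max_norm / norm only when the norm exceeds max_norm, and updates are added to parameters (they are assumed to be already negated). The leaves type, a flat list of float arrays, is a hypothetical stand-in for Ptree.t.

```ocaml
(* A flat list of leaves stands in for Ptree.t (hypothetical). *)
type leaves = float array list

(* L2 norm over every element of every leaf, taken together. *)
let global_norm (t : leaves) =
  let sum_sq acc a = Array.fold_left (fun acc x -> acc +. (x *. x)) acc a in
  sqrt (List.fold_left sum_sq 0.0 t)

(* Rescale all leaves by one common factor when the global norm exceeds
   [max_norm]; otherwise return them untouched. *)
let clip_by_global_norm max_norm (t : leaves) : leaves =
  let n = global_norm t in
  if n <= max_norm then t
  else
    let s = max_norm /. n in
    List.map (Array.map (fun x -> s *. x)) t

(* params + updates, leaf by leaf (updates assumed already negated). *)
let apply_updates (params : leaves) (updates : leaves) : leaves =
  List.map2 (Array.map2 ( +. )) params updates
```

Scaling every leaf by the same factor preserves the direction of the combined update, which is why clipping by the global norm is usually preferred over clipping each leaf separately.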

Pre-configured optimisers

val sgd : lr:Schedule.t -> ?momentum:float -> ?nesterov:bool -> unit -> algorithm
val adam : lr:Schedule.t -> ?b1:float -> ?b2:float -> ?eps:float -> unit -> algorithm
val adamw : lr:Schedule.t -> ?b1:float -> ?b2:float -> ?eps:float -> ?weight_decay:float -> unit -> algorithm
val rmsprop : lr:Schedule.t -> ?decay:float -> ?eps:float -> ?momentum:float -> unit -> algorithm
val adagrad : lr:Schedule.t -> ?eps:float -> unit -> algorithm
val adabelief : lr:Schedule.t -> ?b1:float -> ?b2:float -> ?eps:float -> unit -> algorithm
val lamb : lr:Schedule.t -> ?b1:float -> ?b2:float -> ?eps:float -> ?weight_decay:float -> unit -> algorithm
val radam : lr:Schedule.t -> ?b1:float -> ?b2:float -> ?eps:float -> unit -> algorithm
val yogi : lr:Schedule.t -> ?b1:float -> ?b2:float -> ?eps:float -> unit -> algorithm
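Optax builds sgd as a momentum trace followed by a negative learning-rate scale. A minimal sketch, assuming Kaun's sgd follows the same recipe: the hypothetical trace_update and sgd_step helpers work on flat arrays and take a constant float lr rather than a Schedule.t. With nesterov = true the update adds one extra momentum look-ahead to the raw gradient.

```ocaml
(* Optax-style trace: new_trace = g + decay * trace. With Nesterov, the
   update is g + decay * new_trace; otherwise it is new_trace itself. *)
let trace_update ~decay ~nesterov trace grads =
  let new_trace = Array.mapi (fun i g -> g +. (decay *. trace.(i))) grads in
  let updates =
    if nesterov then Array.mapi (fun i g -> g +. (decay *. new_trace.(i))) grads
    else new_trace
  in
  (updates, new_trace)

(* One hypothetical SGD step: trace, then scale by -lr and add to params. *)
let sgd_step ~lr ?(momentum = 0.0) ?(nesterov = false) params trace grads =
  let updates, trace = trace_update ~decay:momentum ~nesterov trace grads in
  let params = Array.mapi (fun i p -> p -. (lr *. updates.(i))) params in
  (params, trace)
```

With momentum = 0.0 the trace is just the gradient, so the step reduces to plain gradient descent.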