
Module Kaun.Training

High-level training utilities operating on Train_state.

module History : sig ... end

Helper functions for accessing training history.

module Callbacks : sig ... end

Callback system for training hooks.

val train_step :
  model:Layer.module_ ->
  optimizer:Optimizer.algorithm ->
  state:Train_state.t ->
  x:(float, 'layout) Rune.t ->
  y:(float, 'layout) Rune.t ->
  loss_fn:((float, 'layout) Rune.t -> (float, 'layout) Rune.t -> (float, 'layout) Rune.t) ->
  Train_state.t * float

Performs a single training step on inputs x and targets y, returning the updated state and the scalar loss.
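The contract of train_step (state in; new state and loss out) can be illustrated without Kaun. The toy sketch below assumes a one-parameter model prediction = w * x, squared-error loss, and plain gradient descent with a hypothetical learning rate lr. Plain floats stand in for Rune tensors, and the state record is a hypothetical stand-in, not Kaun's Train_state.t. The text above does not say whether the reported loss is computed before or after the update; this sketch reports the pre-update loss.

```ocaml
(* Hypothetical stand-in for Train_state.t: one weight plus a step counter. *)
type state = { w : float; step : int }

(* Forward pass of a one-parameter model: prediction = w * x. *)
let predict st x = st.w *. x

(* Squared-error loss and its derivative with respect to w. *)
let loss st ~x ~y =
  let e = predict st x -. y in
  e *. e

let grad st ~x ~y = 2.0 *. (predict st x -. y) *. x

(* One step: evaluate the loss at the current weight, apply a plain
   gradient-descent update, and return the new state with that loss. *)
let train_step ~lr ~state ~x ~y =
  let l = loss state ~x ~y in
  let w = state.w -. (lr *. grad state ~x ~y) in
  ({ w; step = state.step + 1 }, l)

let () =
  let s0 = { w = 0.0; step = 0 } in
  let s1, l = train_step ~lr:0.1 ~state:s0 ~x:1.0 ~y:2.0 in
  Printf.printf "loss=%.2f w=%.2f step=%d\n" l s1.w s1.step
```

The original state s0 is left untouched, so the caller decides whether to keep using it or to continue from s1.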

val eval_step :
  model:Layer.module_ ->
  state:Train_state.t ->
  x:(float, 'layout) Rune.t ->
  y:(float, 'layout) Rune.t ->
  loss_fn:((float, 'layout) Rune.t -> (float, 'layout) Rune.t -> (float, 'layout) Rune.t) ->
  float

Computes the loss on inputs x and targets y without updating state; only the scalar loss is returned.
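Because eval_step returns only a float, evaluation behaves like a read-only function of the state. The toy sketch below assumes a one-parameter model prediction = w * x with squared-error loss. The state record and names are hypothetical stand-ins, not Kaun's Train_state.t, and plain floats replace Rune tensors.

```ocaml
(* Hypothetical stand-in for Train_state.t: the single weight of the toy
   model prediction = w * x. *)
type state = { w : float }

(* Evaluation only reads the state: it computes the squared error and
   returns a float with no new state, mirroring eval_step's result type. *)
let eval_step ~state ~x ~y =
  let e = (state.w *. x) -. y in
  e *. e

let () =
  let st = { w = 1.5 } in
  let a = eval_step ~state:st ~x:2.0 ~y:2.0 in
  let b = eval_step ~state:st ~x:2.0 ~y:2.0 in
  (* The same state gives the same loss on every call. *)
  Printf.printf "loss=%.2f repeatable=%b\n" a (a = b)
```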

val train_epoch :
  model:Layer.module_ ->
  optimizer:Optimizer.algorithm ->
  state:Train_state.t ->
  dataset:((float, 'layout) Rune.t * (float, 'layout) Rune.t) Dataset.t ->
  loss_fn:((float, 'layout) Rune.t -> (float, 'layout) Rune.t -> (float, 'layout) Rune.t) ->
  ?progress:bool ->
  unit ->
  Train_state.t * float * (string * float) list

Runs one training epoch over dataset, returning the updated state, the average loss, and the metrics as (name, value) pairs.
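Assuming an epoch amounts to repeated train_step calls (the text above does not say so), its core is a left fold that threads the state through the dataset. This sketch uses hypothetical names and plain floats, stands in a list of (x, y) pairs for Dataset.t, and omits the metrics list. It takes an unweighted mean over batches; whether Kaun weights the average by batch size is not stated.

```ocaml
(* Sketch of an epoch loop: thread the state through every batch in order and
   average the per-batch losses. [step] plays the role of train_step and the
   list of (x, y) pairs stands in for Dataset.t; metrics are omitted. *)
let train_epoch ~step ~state batches =
  let state, total, n =
    List.fold_left
      (fun (st, total, n) (x, y) ->
        let st', l = step st x y in
        (st', total +. l, n + 1))
      (state, 0.0, 0) batches
  in
  (* Guard against an empty dataset instead of dividing by zero. *)
  let avg = if n = 0 then 0.0 else total /. float_of_int n in
  (state, avg)

let () =
  (* Toy step: the "state" counts updates and the loss is the absolute error. *)
  let step st x y = (st + 1, abs_float (x -. y)) in
  let steps, avg = train_epoch ~step ~state:0 [ (1.0, 0.0); (3.0, 0.0) ] in
  Printf.printf "steps=%d avg_loss=%.2f\n" steps avg
```

Threading the state through the fold means each batch sees the parameters produced by the previous one, which is what distinguishes an epoch from evaluating every batch against the same state.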

val evaluate :
  model:Layer.module_ ->
  state:Train_state.t ->
  dataset:((float, 'layout) Rune.t * (float, 'layout) Rune.t) Dataset.t ->
  loss_fn:((float, 'layout) Rune.t -> (float, 'layout) Rune.t -> (float, 'layout) Rune.t) ->
  ?progress:bool ->
  unit ->
  float * (string * float) list

Evaluates the model over dataset without updating state, returning the average loss and the metrics as (name, value) pairs.
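Dataset-level evaluation differs from an epoch in that the same state is used for every batch. The sketch below is a hypothetical illustration with plain floats: eval plays the role of eval_step, a list of (x, y) pairs stands in for Dataset.t, metrics are omitted, and the unweighted per-batch mean is an assumption.

```ocaml
(* Sketch of dataset evaluation: the state is read, never updated, and the
   per-batch losses are averaged. [eval] plays the role of eval_step and the
   list of (x, y) pairs stands in for Dataset.t; metrics are omitted. *)
let evaluate ~eval ~state batches =
  match batches with
  | [] -> 0.0 (* avoid dividing by zero on an empty dataset *)
  | _ ->
      let total =
        List.fold_left (fun acc (x, y) -> acc +. eval state x y) 0.0 batches
      in
      total /. float_of_int (List.length batches)

let () =
  (* Toy model prediction = w * x with squared-error loss and fixed w = 2. *)
  let eval w x y =
    let e = (w *. x) -. y in
    e *. e
  in
  let avg = evaluate ~eval ~state:2.0 [ (1.0, 2.0); (2.0, 3.0) ] in
  Printf.printf "avg_loss=%.2f\n" avg
```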

val fit :
  model:Layer.module_ ->
  optimizer:Optimizer.algorithm ->
  loss_fn:((float, 'layout) Rune.t -> (float, 'layout) Rune.t -> (float, 'layout) Rune.t) ->
  ?metrics:Metrics.Collection.t ->
  train_data:((float, 'layout) Rune.t * (float, 'layout) Rune.t) Dataset.t ->
  ?val_data:((float, 'layout) Rune.t * (float, 'layout) Rune.t) Dataset.t ->
  epochs:int ->
  ?callbacks:Callbacks.t list ->
  ?progress:bool ->
  rngs:Rune.Rng.key ->
  dtype:(float, 'layout) Rune.dtype ->
  unit ->
  Train_state.t * History.t

Trains on train_data for the given number of epochs, optionally evaluating on val_data, and returns the final state together with the accumulated training history.
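A fit loop can be pictured as repeated epochs that each append to a history. The sketch below is hypothetical throughout: the history record is not Kaun's History.t, plain floats replace Rune tensors, and lists replace Dataset.t. Unlike the real signature, which takes rngs and dtype rather than an initial state, the sketch takes an initial state directly. It also omits metrics, callbacks, and progress, and assumes validation uses the state produced at the end of each epoch.

```ocaml
(* Hypothetical per-epoch history, oldest epoch first; not Kaun's History.t. *)
type history = { train_loss : float list; val_loss : float list }

let mean = function
  | [] -> 0.0
  | l -> List.fold_left ( +. ) 0.0 l /. float_of_int (List.length l)

(* Sketch of a fit loop: for each epoch, thread the state through the training
   batches, record the mean training loss and, if validation data is given,
   record the mean validation loss computed with the updated state. *)
let fit ~step ~eval ~state ~train_data ?val_data ~epochs () =
  let rec loop epoch state tr va =
    if epoch >= epochs then
      (state, { train_loss = List.rev tr; val_loss = List.rev va })
    else
      let state, losses =
        List.fold_left
          (fun (st, acc) (x, y) ->
            let st', l = step st x y in
            (st', l :: acc))
          (state, []) train_data
      in
      let va =
        match val_data with
        | None -> va
        | Some vd -> mean (List.map (fun (x, y) -> eval state x y) vd) :: va
      in
      loop (epoch + 1) state (mean losses :: tr) va
  in
  loop 0 state [] []

let () =
  (* Toy model prediction = w * x, squared error, gradient descent with lr 0.1. *)
  let sq_err w x y =
    let e = (w *. x) -. y in
    e *. e
  in
  let step w x y = (w -. (0.1 *. 2.0 *. ((w *. x) -. y) *. x), sq_err w x y) in
  let data = [ (1.0, 2.0) ] in
  let w, hist =
    fit ~step ~eval:sq_err ~state:0.0 ~train_data:data ~val_data:data ~epochs:3 ()
  in
  List.iteri
    (fun i l -> Printf.printf "epoch %d train_loss=%.4f\n" (i + 1) l)
    hist.train_loss;
  Printf.printf "final w=%.4f\n" w
```

Accumulating the history in reverse and reversing once at the end keeps each epoch's append O(1).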