package kaun


Module Training.Callbacks

Callback system for training hooks

type t

Abstract callback type

type context = {
  epoch : int;
  state : Train_state.t;
  model : Layer.module_;
  optimizer : Optimizer.algorithm;
  history : History.t;
  train_loss : float option;
  val_loss : float option;
  train_metrics : (string * float) list;
  val_metrics : (string * float) list;
}

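Context passed to callback hooks: the current epoch, the training state, model, optimizer and history, plus the most recent training and validation losses and metrics.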
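
Several callbacks below select a value through a monitor name such as "val_loss". The sketch shows one plausible way such a name could be resolved against a context. The record ctx is a hypothetical cut-down copy of context (loss and metric fields only), and the lookup order (loss fields first, then val_metrics, then train_metrics) is an assumption, not documented kaun behaviour.

```ocaml
(* Hypothetical cut-down copy of Callbacks.context: only the loss and
   metric fields are kept so the sketch compiles on its own. *)
type ctx = {
  train_loss : float option;
  val_loss : float option;
  train_metrics : (string * float) list;
  val_metrics : (string * float) list;
}

(* Assumed resolution order: the two loss fields by name, then an exact
   match in val_metrics, then an exact match in train_metrics. *)
let lookup_metric (c : ctx) (name : string) : float option =
  match name with
  | "val_loss" -> c.val_loss
  | "train_loss" -> c.train_loss
  | _ -> (
      match List.assoc_opt name c.val_metrics with
      | Some v -> Some v
      | None -> List.assoc_opt name c.train_metrics)
```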

val early_stopping : ?monitor:string -> ?patience:int -> ?mode:[ `Min | `Max ] -> ?min_delta:float -> ?baseline:float option -> unit -> t

early_stopping ?monitor ?patience ?mode ?min_delta ?baseline () creates an early stopping callback.

  • monitor: Metric to monitor (default: "val_loss")
  • patience: Number of epochs with no improvement to wait before stopping (default: 5)
  • mode: Whether to minimize or maximize the metric (default: `Min)
  • min_delta: Minimum change to qualify as improvement (default: 0.0)
  • baseline: Baseline value for the monitored metric; training stops if the metric does not improve on it
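
A minimal sketch of the improvement test these parameters describe, assuming Keras-style semantics: an epoch counts as an improvement only if it beats the best value so far by more than min_delta, and training stops once patience consecutive epochs fail to improve. The tracker type, make_tracker and step are hypothetical, and treating baseline as the initial best value is an assumption.

```ocaml
type mode = [ `Min | `Max ]

(* Hypothetical early-stopping tracker; not part of kaun's API. *)
type tracker = {
  mode : mode;
  min_delta : float;
  patience : int;
  mutable best : float option; (* best value so far, seeded by baseline (assumed) *)
  mutable wait : int;          (* epochs since the last improvement *)
}

let make_tracker ?(mode = `Min) ?(min_delta = 0.0) ?(patience = 5) ?baseline () =
  { mode; min_delta; patience; best = baseline; wait = 0 }

(* An improvement must beat the best value by more than [min_delta]. *)
let improved t value =
  match t.best, t.mode with
  | None, _ -> true
  | Some best, `Min -> value < best -. t.min_delta
  | Some best, `Max -> value > best +. t.min_delta

(* Feed one epoch's metric; returns [true] to continue, [false] to stop. *)
let step t value =
  if improved t value then begin
    t.best <- Some value;
    t.wait <- 0;
    true
  end else begin
    t.wait <- t.wait + 1;
    t.wait < t.patience
  end
```

With patience 2, the metric sequence 1.0, 0.9, 0.95, 0.95 keeps training for three epochs and stops on the fourth.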

val model_checkpoint : filepath:string -> ?monitor:string -> ?mode:[ `Min | `Max ] -> ?save_best_only:bool -> ?save_freq:[ `Epoch of int | `Best ] -> unit -> t

model_checkpoint ~filepath ?monitor ?mode ?save_best_only ?save_freq () creates a checkpoint callback.

  • filepath: Path pattern for saving checkpoints; may contain an {epoch} placeholder
  • monitor: Metric to monitor for best model (default: "val_loss")
  • mode: Whether to minimize or maximize the metric (default: `Min)
  • save_best_only: Only save when the monitored metric improves (default: true)
  • save_freq: When to save: `Epoch n saves every n epochs, `Best saves only when the monitored metric improves (default: `Best)
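
To illustrate filepath and save_freq, here is a sketch that expands the {epoch} placeholder and decides whether to save at a given epoch. replace_all, expand_filepath and should_save are hypothetical helpers, the epoch numbering is assumed, and the interaction with save_best_only is not modelled.

```ocaml
(* Replace every occurrence of [pat] in [s] with [by] (Stdlib only). *)
let replace_all ~pat ~by s =
  let plen = String.length pat and slen = String.length s in
  if plen = 0 then s
  else begin
    let buf = Buffer.create (slen + 16) in
    let rec go i =
      if i > slen - plen then Buffer.add_string buf (String.sub s i (slen - i))
      else if String.sub s i plen = pat then begin
        Buffer.add_string buf by;
        go (i + plen)
      end else begin
        Buffer.add_char buf s.[i];
        go (i + 1)
      end
    in
    go 0;
    Buffer.contents buf
  end

(* Hypothetical: substitute the epoch number into the path pattern. *)
let expand_filepath filepath epoch =
  replace_all ~pat:"{epoch}" ~by:(string_of_int epoch) filepath

type save_freq = [ `Epoch of int | `Best ]

(* Hypothetical rule: `Epoch n saves on every n-th epoch,
   `Best saves only when the monitored metric improved this epoch. *)
let should_save ~(save_freq : save_freq) ~epoch ~improved =
  match save_freq with
  | `Epoch n -> n > 0 && epoch mod n = 0
  | `Best -> improved
```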

val reduce_lr_on_plateau : ?monitor:string -> ?factor:float -> ?patience:int -> ?mode:[ `Min | `Max ] -> ?min_delta:float -> ?cooldown:int -> ?min_lr:float -> unit -> t

reduce_lr_on_plateau ?monitor ?factor ?patience ?mode ?min_delta ?cooldown ?min_lr () creates a learning rate reduction callback.

  • monitor: Metric to monitor (default: "val_loss")
  • factor: Multiplicative factor applied to the learning rate on each reduction (default: 0.1)
  • patience: Number of epochs with no improvement to wait before reducing the learning rate (default: 10)
  • mode: Whether to minimize or maximize the metric (default: `Min)
  • min_delta: Minimum change to qualify as improvement (default: 0.0001)
  • cooldown: Number of epochs to wait after a reduction before resuming normal operation (default: 0)
  • min_lr: Lower bound on learning rate (default: 0.0)
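
A sketch of the plateau logic under assumed Keras-like semantics: after patience epochs without an improvement larger than min_delta, the learning rate is multiplied by factor and clamped at min_lr, and the wait counter is then suspended for cooldown epochs. The plateau record, make and step are hypothetical; step returns the learning rate to use next instead of touching a real optimizer.

```ocaml
type mode = [ `Min | `Max ]

(* Hypothetical plateau tracker; defaults copied from the list above. *)
type plateau = {
  mode : mode;
  factor : float;
  patience : int;
  min_delta : float;
  cooldown : int;
  min_lr : float;
  mutable best : float option;
  mutable wait : int;          (* epochs without improvement *)
  mutable cooldown_left : int; (* epochs left before counting resumes *)
}

let make ?(mode = `Min) ?(factor = 0.1) ?(patience = 10) ?(min_delta = 1e-4)
    ?(cooldown = 0) ?(min_lr = 0.0) () =
  { mode; factor; patience; min_delta; cooldown; min_lr;
    best = None; wait = 0; cooldown_left = 0 }

let improved p v =
  match p.best, p.mode with
  | None, _ -> true
  | Some b, `Min -> v < b -. p.min_delta
  | Some b, `Max -> v > b +. p.min_delta

(* Feed one epoch's metric and the current learning rate;
   returns the learning rate for the next epoch. *)
let step p ~lr v =
  if p.cooldown_left > 0 then begin
    p.cooldown_left <- p.cooldown_left - 1;
    p.wait <- 0
  end;
  if improved p v then begin
    p.best <- Some v;
    p.wait <- 0;
    lr
  end else if p.cooldown_left > 0 then lr
  else begin
    p.wait <- p.wait + 1;
    if p.wait >= p.patience then begin
      p.wait <- 0;
      p.cooldown_left <- p.cooldown;
      Float.max p.min_lr (lr *. p.factor)
    end else lr
  end
```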

val tensorboard : log_dir:string -> ?update_freq:[ `Epoch | `Batch of int ] -> unit -> t

tensorboard ~log_dir ?update_freq () creates a TensorBoard logging callback.

  • log_dir: Directory in which to write TensorBoard logs
  • update_freq: How often to write logs: `Epoch once per epoch, `Batch n every n batches (default: `Epoch)
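
One possible reading of update_freq, written as a pure decision function. should_write is hypothetical, and whether `Batch n also writes an end-of-epoch summary is not stated above; the sketch assumes it does.

```ocaml
type update_freq = [ `Epoch | `Batch of int ]

(* [batch] is [Some b] after the b-th batch (1-based, assumed) and
   [None] at the end of an epoch. *)
let should_write (freq : update_freq) ~batch =
  match freq, batch with
  | `Epoch, None -> true
  | `Epoch, Some _ -> false
  | `Batch n, Some b -> n > 0 && b mod n = 0
  | `Batch _, None -> true (* assumption: epoch summaries are still written *)
```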

val custom : ?on_epoch_begin:(context -> bool) -> ?on_epoch_end:(context -> bool) -> ?on_train_begin:(context -> unit) -> ?on_train_end:(context -> unit) -> unit -> t

custom ?on_epoch_begin ?on_epoch_end ?on_train_begin ?on_train_end () creates a callback from user-defined hooks. Returning false from on_epoch_begin or on_epoch_end stops training.
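
The stop-on-false contract can be seen in a toy training loop. Everything here is a hypothetical stand-in: ctx keeps only epoch and val_loss, hooks mirrors the four optional hooks of custom, and run fakes the validation loss with a caller-supplied function.

```ocaml
(* Hypothetical stand-in for Callbacks.context. *)
type ctx = { epoch : int; val_loss : float option }

(* Hypothetical record mirroring the four hooks accepted by [custom]. *)
type hooks = {
  on_train_begin : ctx -> unit;
  on_epoch_begin : ctx -> bool;
  on_epoch_end : ctx -> bool;
  on_train_end : ctx -> unit;
}

(* Missing hooks default to no-ops that let training continue. *)
let custom ?(on_epoch_begin = fun _ -> true) ?(on_epoch_end = fun _ -> true)
    ?(on_train_begin = fun _ -> ()) ?(on_train_end = fun _ -> ()) () =
  { on_train_begin; on_epoch_begin; on_epoch_end; on_train_end }

(* Toy loop: runs at most [max_epochs] and stops as soon as an epoch hook
   returns false. Returns the number of completed epochs. *)
let run ~max_epochs ~loss_at (h : hooks) =
  h.on_train_begin { epoch = 0; val_loss = None };
  let rec loop epoch =
    if epoch > max_epochs then max_epochs
    else
      let ctx = { epoch; val_loss = None } in
      if not (h.on_epoch_begin ctx) then epoch - 1
      else
        let ctx = { ctx with val_loss = Some (loss_at epoch) } in
        if h.on_epoch_end ctx then loop (epoch + 1) else epoch
  in
  let completed = loop 1 in
  h.on_train_end { epoch = completed; val_loss = None };
  completed
```

For example, a hook that returns false once val_loss drops below 0.3 ends a run whose loss is 1/epoch after epoch 4.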

val combine : t list -> t

Combine multiple callbacks into one
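
How combine merges hooks is not specified above. A sketch under one reasonable assumption: every callback's hook runs in list order, and the combined epoch hook returns true only if all of them did, so a later callback (such as a logger) still sees the epoch on which an earlier one asked to stop. The types and combine here are hypothetical stand-ins, not kaun's implementation.

```ocaml
(* Hypothetical minimal context and hook record. *)
type ctx = { epoch : int }

type hooks = {
  on_train_begin : ctx -> unit;
  on_epoch_begin : ctx -> bool;
  on_epoch_end : ctx -> bool;
  on_train_end : ctx -> unit;
}

(* Call every boolean hook (no short-circuit) and continue only if all
   returned true. Evaluating [f c] before [ok] ensures each hook runs. *)
let all_of (fs : (ctx -> bool) list) (c : ctx) =
  List.fold_left (fun ok f -> f c && ok) true fs

let combine (cbs : hooks list) : hooks =
  {
    on_train_begin = (fun c -> List.iter (fun h -> h.on_train_begin c) cbs);
    on_epoch_begin = all_of (List.map (fun h -> h.on_epoch_begin) cbs);
    on_epoch_end = all_of (List.map (fun h -> h.on_epoch_end) cbs);
    on_train_end = (fun c -> List.iter (fun h -> h.on_train_end c) cbs);
  }
```

Writing the fold as f c && ok rather than ok && f c is the design choice that keeps every callback informed after one of them votes to stop.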