package kaun

  1. Overview
  2. Docs

Source file kaun.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
(* [Rune.t] with its first type parameter fixed to [float]; the second
   parameter, ['layout], is left polymorphic. *)
type 'layout tensor = (float, 'layout) Rune.t
(* [Rune.dtype] with its first type parameter fixed to [float], in the same
   way as [tensor]. *)
type 'layout dtype = (float, 'layout) Rune.dtype

(* Parameter tree. [params] is declared equal to [Ptree.t] and repeats its
   constructors, so [Tensor], [List] and [Dict] can be built and matched
   through this module without qualifying them with [Ptree]. *)
type params = Ptree.t =
  | Tensor of Ptree.tensor (* leaf holding one [Ptree.tensor] *)
  | List of params list (* list of sub-trees *)
  | Dict of (string * params) list (* sub-trees paired with string keys *)

(* Record of two functions, [init] and [apply]. [module_] is declared equal
   to [Layer.module_] and repeats its fields so they are usable from this
   module without qualification. Both fields are explicitly polymorphic in
   ['layout], so one record value can be used at any layout. *)
type module_ = Layer.module_ = {
  (* Takes an RNG key and a [(float, 'layout) Rune.dtype] and returns a
     [Ptree.t]. NOTE(review): presumably the module's initial parameters --
     confirm against [Layer]. *)
  init :
    'layout. rngs:Rune.Rng.key -> dtype:(float, 'layout) Rune.dtype -> Ptree.t;
  (* Takes a [Ptree.t], a [training] flag, an optional RNG key and a value of
     type [(float, 'layout) Rune.t]; returns a value of that same type. How
     [training] and [rngs] are interpreted is not fixed by this type. *)
  apply :
    'layout.
    Ptree.t ->
    training:bool ->
    ?rngs:Rune.Rng.key ->
    (float, 'layout) Rune.t ->
    (float, 'layout) Rune.t;
}

(* [init m ~rngs ~dtype] calls the [init] field of [m], forwarding [rngs] and
   [dtype] unchanged, and returns the [Ptree.t] that field produces. *)
let init { init = build; _ } ~rngs ~dtype = build ~rngs ~dtype
(* [apply m params ~training ?rngs x] calls the [apply] field of [m] on
   [params] and [x], forwarding [training] and the optional [rngs] as given,
   and returns that field's result. *)
let apply { apply = run; _ } params ~training ?rngs x =
  run params ~training ?rngs x
(* Alias for [Transformations.value_and_grad]; see that module for its
   signature and semantics. *)
let value_and_grad = Transformations.value_and_grad
(* Alias for [Transformations.grad]; see that module for its signature and
   semantics. *)
let grad = Transformations.grad

(* Re-export each module under its own name so it is reachable as a member of
   this module (e.g. [Kaun.Layer]).
   NOTE(review): presumably these are sibling modules of a wrapped library --
   confirm against the build configuration. *)
module Metrics = Metrics
module Loss = Loss
module Initializers = Initializers
module Attention = Attention
module Layer = Layer
module Checkpoint = Checkpoint
module Train_state = Train_state
module Ptree = Ptree
module Optimizer = Optimizer
module Activations = Activations
module Dataset = Dataset
module Training = Training