package kaun

  1. Overview
  2. Docs

Source file train_state.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
module Snapshot = Checkpoint.Snapshot

type t = {
  step : int;
  params : Ptree.t;
  opt_state : Optimizer.state;
  rng : Rune.Rng.key;
  metrics : Metrics.Collection.t option;
}

let init ~model ~optimizer ?metrics ~rngs ~dtype () =
  let params = model.Layer.init ~rngs ~dtype in
  let opt_state = Optimizer.init optimizer params in
  { step = 0; params; opt_state; rng = rngs; metrics }

let create ?(step = 0) ~params ~opt_state ~rng ?metrics () =
  { step; params; opt_state; rng; metrics }

let apply_gradients ~optimizer ~grads state =
  let updates, opt_state =
    Optimizer.step optimizer state.opt_state state.params grads
  in
  Optimizer.apply_updates_inplace state.params updates;
  { state with opt_state; step = state.step + 1 }

let next_rng state =
  let split = Rune.Rng.split state.rng in
  if Array.length split < 2 then
    invalid_arg "Train_state.next_rng: expected Rune.Rng.split to return 2 keys";
  (split.(0), { state with rng = split.(1) })

let reset_metrics state =
  (match state.metrics with
  | Some metrics -> Metrics.Collection.reset metrics
  | None -> ());
  state

let update_metrics state ~predictions ~targets ?loss ?weights () =
  match state.metrics with
  | None -> ()
  | Some metrics ->
      Metrics.Collection.update metrics ~predictions ~targets ?loss ?weights ()

let compute_metrics state =
  match state.metrics with
  | None -> []
  | Some metrics -> Metrics.Collection.compute metrics

let schema_key = "schema"
let schema_value = "kaun.train_state/1"
let checkpoint_slug = "state"

let to_snapshot ?encode_metrics
    ({ step; params; opt_state; rng = rng_key; metrics } : t) =
  let open Snapshot in
  let base_entries =
    [
      (schema_key, string schema_value);
      ("step", int step);
      ("params", ptree params);
      ("optimizer_state", Optimizer.serialize opt_state);
      ("rng", Snapshot.rng rng_key);
    ]
  in
  match (metrics, encode_metrics) with
  | None, _ -> record base_entries
  | Some metrics_value, Some encode ->
      record (("metrics", encode metrics_value) :: base_entries)
  | Some _, None ->
      invalid_arg
        "Train_state.to_snapshot: metrics present but encode_metrics missing"

let of_snapshot ~optimizer ?decode_metrics snapshot =
  let open Result in
  let open Snapshot in
  let ( let* ) = bind in
  let error msg = Error ("Train_state.of_snapshot: " ^ msg) in

  let* record =
    match snapshot with Record r -> Ok r | _ -> error "expected record"
  in

  let validate_schema record =
    match Snapshot.Record.find_opt schema_key record with
    | None -> Ok ()
    | Some (Scalar (String value)) ->
        if String.equal value schema_value then Ok ()
        else error ("unsupported schema " ^ value)
    | Some _ -> error "invalid schema field"
  in

  let* () = validate_schema record in

  let find field =
    match Snapshot.Record.find_opt field record with
    | Some value -> Ok value
    | None -> error ("missing field " ^ field)
  in

  let decode_step = function
    | Scalar (Int i) -> Ok i
    | Scalar (Float f) -> Ok (int_of_float f)
    | _ -> error "expected integer step"
  in

  let decode_rng = function
    | Scalar (Int seed) -> Ok (Rune.Rng.key seed)
    | Scalar (Float f) -> Ok (Rune.Rng.key (int_of_float f))
    | _ -> error "expected RNG scalar"
  in

  let decode_metrics_field () =
    match Snapshot.Record.find_opt "metrics" record with
    | None -> Ok None
    | Some value -> (
        match decode_metrics with
        | None ->
            error
              "metrics present but decode_metrics missing; provide a decoder"
        | Some decode -> (
            match decode value with
            | Ok metrics -> Ok (Some metrics)
            | Error msg -> error msg))
  in

  let* params_node = find "params" in
  let* params =
    match Snapshot.to_ptree params_node with
    | Ok params -> Ok params
    | Error msg -> error msg
  in

  let* opt_state_node = find "optimizer_state" in
  let* opt_state = Optimizer.restore optimizer opt_state_node in

  let* rng_node = find "rng" in
  let* rng = decode_rng rng_node in

  let* step_node = find "step" in
  let* step = decode_step step_node in

  let* metrics = decode_metrics_field () in
  Ok { step; params; opt_state; rng; metrics }

let make_artifact ?encode_metrics state =
  let snapshot = to_snapshot ?encode_metrics state in
  Checkpoint.artifact ~label:"state" ~kind:(Checkpoint.Custom checkpoint_slug)
    ~snapshot ()

let find_artifact artifacts =
  List.find_map
    (fun artifact ->
      if String.equal (Checkpoint.artifact_slug artifact) checkpoint_slug then
        Some artifact
      else None)
    artifacts

let save ~repository ?step ?tags ?metadata ?encode_metrics state =
  let step = Option.value ~default:state.step step in
  let metadata = Option.value ~default:[] metadata in
  let artifact = make_artifact ?encode_metrics state in
  Checkpoint.write repository ~step ?tags ~metadata ~artifacts:[ artifact ]

let load ~repository ~step ~optimizer ?decode_metrics () =
  match Checkpoint.read repository ~step with
  | Error err -> Error ("Train_state.load: " ^ Checkpoint.error_to_string err)
  | Ok (_manifest, artifacts) -> (
      match find_artifact artifacts with
      | None -> Error "Train_state.load: missing state artifact"
      | Some artifact ->
          let snapshot = Checkpoint.artifact_snapshot artifact in
          of_snapshot ?decode_metrics ~optimizer snapshot)