Source file training.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
(** Per-epoch training history accumulated by the training loop.
    Losses and metric values are appended in epoch order, so the list
    head is epoch 0 and the last element is the most recent epoch. *)
module History = struct
  type t = {
    train_loss : float list;  (** One entry per completed training epoch. *)
    train_metrics : (string * float list) list;
        (** Per-metric value series, keyed by metric name. *)
    val_loss : float list option;
        (** [None] when no validation data was supplied. *)
    val_metrics : (string * float list) list option;
        (** [None] when no validation data was supplied. *)
  }

  (* [last xs] is the most recently appended element of [xs], if any. *)
  let last xs = match List.rev xs with [] -> None | hd :: _ -> Some hd

  (* Minimum of a list, folding from the head instead of a
     [Float.max_float] sentinel so extreme values (e.g. [infinity])
     are reported faithfully. *)
  let minimum = function
    | [] -> None
    | hd :: tl -> Some (List.fold_left min hd tl)

  (* Final recorded value per metric; 0.0 for a metric with no values. *)
  let final_metric_values metrics =
    List.map
      (fun (name, values) -> (name, Option.value (last values) ~default:0.0))
      metrics

  (** Last recorded training loss, or [None] before the first epoch. *)
  let final_train_loss history = last history.train_loss

  (** Last recorded validation loss, or [None] when absent or empty. *)
  let final_val_loss history = Option.bind history.val_loss last

  (** Last recorded value of every training metric. *)
  let final_train_metrics history = final_metric_values history.train_metrics

  (** Last recorded value of every validation metric ([] when absent). *)
  let final_val_metrics history =
    match history.val_metrics with
    | None -> []
    | Some metrics -> final_metric_values metrics

  (** Smallest training loss seen so far, or [None] before any epoch. *)
  let best_train_loss history = minimum history.train_loss

  (** Smallest validation loss seen so far, or [None] when absent/empty. *)
  let best_val_loss history = Option.bind history.val_loss minimum

  (** [best_epoch ?monitor history] is the 0-based index of the epoch
      where [monitor] was smallest (ties resolve to the earliest epoch).
      [monitor] may be ["val_loss"], ["train_loss"], or a metric name —
      names starting with ["val_"] are looked up among validation
      metrics, all others among training metrics. [None] when the
      monitored series is unavailable or empty. *)
  let best_epoch ?(monitor = "val_loss") history =
    let values =
      if monitor = "val_loss" then Option.value history.val_loss ~default:[]
      else if monitor = "train_loss" then history.train_loss
      else
        let find_in_metrics metrics_opt =
          match metrics_opt with
          | None -> []
          | Some metrics ->
              List.find_opt (fun (name, _) -> name = monitor) metrics
              |> Option.map snd |> Option.value ~default:[]
        in
        if String.starts_with ~prefix:"val_" monitor then
          find_in_metrics history.val_metrics
        else find_in_metrics (Some history.train_metrics)
    in
    match values with
    | [] -> None
    | first :: rest ->
        (* Track (best index, best value, next index) across the tail. *)
        let best_idx, _, _ =
          List.fold_left
            (fun (best_i, best_v, i) v ->
              if v < best_v then (i, v, i + 1) else (best_i, best_v, i + 1))
            (0, first, 1) rest
        in
        Some best_idx
end
(* Fold one epoch's metric values into the accumulated per-metric
   histories. Metrics already present get the new value appended (or are
   left unchanged when absent this epoch); metrics seen for the first
   time start a fresh single-element history, appended in their
   original order. *)
let merge_metric_history existing new_metrics =
  let extended =
    List.map
      (fun (name, history) ->
        match List.assoc_opt name new_metrics with
        | None -> (name, history)
        | Some v -> (name, history @ [ v ]))
      existing
  in
  let fresh =
    List.filter_map
      (fun (name, v) ->
        if List.mem_assoc name existing then None else Some (name, [ v ]))
      new_metrics
  in
  extended @ fresh
(* Fold a batch of metric values into an optional accumulated history.
   [None] stays [None] only when there is nothing to record; any
   non-empty [new_metrics] yields [Some]. *)
let update_optional_metric_history existing_opt new_metrics =
  match (existing_opt, new_metrics) with
  | None, [] -> None
  | None, _ -> Some (merge_metric_history [] new_metrics)
  | Some existing, [] -> Some existing
  | Some existing, _ -> Some (merge_metric_history existing new_metrics)
(* Run one optimization step on batch [(x, y)]: differentiate the loss
   w.r.t. the current parameters, apply the gradients through
   [optimizer], then record metrics via a second, inference-mode forward
   pass using the updated parameters. Returns the new state and the
   scalar loss value. *)
let train_step ~model ~optimizer ~(state : Train_state.t) ~x ~y ~loss_fn =
  let objective params =
    let logits = model.Layer.apply params ~training:true x in
    loss_fn logits y
  in
  let loss, grads = Transformations.value_and_grad objective state.params in
  let state = Train_state.apply_gradients ~optimizer ~grads state in
  let predictions = model.Layer.apply state.params ~training:false x in
  Train_state.update_metrics state ~predictions ~targets:y ~loss ();
  (state, Rune.item [] loss)
(* Evaluate one batch without updating parameters: forward pass in
   inference mode, accumulate metrics, and return the scalar loss. *)
let eval_step ~model ~(state : Train_state.t) ~x ~y ~loss_fn =
  let predictions = model.Layer.apply state.params ~training:false x in
  let batch_loss = loss_fn predictions y in
  Train_state.update_metrics state ~predictions ~targets:y ~loss:batch_loss ();
  Rune.item [] batch_loss
(* Run one full training pass over [dataset], threading the state
   through [train_step] once per batch. Returns the final state, the
   mean per-batch loss, and the epoch's computed metric values.
   Raises [Invalid_argument] when the dataset yields no batches. *)
let train_epoch ~model ~optimizer ~(state : Train_state.t) ~dataset ~loss_fn
    ?(progress = false) () =
  let current = ref (Train_state.reset_metrics state) in
  let loss_sum = ref 0. in
  let steps = ref 0 in
  let elapsed = ref 0. in
  if progress then Printf.printf "Training: ";
  let on_batch (x, y) =
    incr steps;
    let started = Unix.gettimeofday () in
    let next, batch_loss =
      train_step ~model ~optimizer ~state:!current ~x ~y ~loss_fn
    in
    elapsed := !elapsed +. (Unix.gettimeofday () -. started);
    current := next;
    loss_sum := !loss_sum +. batch_loss;
    (* One dot every 10 batches keeps the progress line compact. *)
    if progress && !steps mod 10 = 0 then Printf.printf "."
  in
  Dataset.iter on_batch dataset;
  if !steps = 0 then
    invalid_arg
      "Training.train_epoch: dataset produced no batches. Ensure your dataset \
       yields at least one batch per epoch.";
  (if progress then
     let ms_per_step = !elapsed /. float_of_int !steps *. 1000. in
     Printf.printf " done (%d steps, avg %.1fms/step)\n%!" !steps ms_per_step);
  let mean_loss = !loss_sum /. float_of_int !steps in
  (!current, mean_loss, Train_state.compute_metrics !current)
(** Training-loop callbacks: hooks invoked around each epoch and around
    the whole run, with early-stopping, checkpointing, LR-plateau and
    simple logging implementations provided. *)
module Callbacks = struct
  (** Snapshot of training progress passed to every hook. *)
  type context = {
    epoch : int;  (** Current epoch (1-based); 0 in [on_train_begin]. *)
    state : Train_state.t;
    model : Layer.module_;
    optimizer : Optimizer.algorithm;
    history : History.t;  (** History accumulated so far. *)
    train_loss : float option;
        (** This epoch's training loss; [None] before training runs. *)
    val_loss : float option;
        (** This epoch's validation loss; [None] when unavailable. *)
    train_metrics : (string * float) list;
    val_metrics : (string * float) list;
  }

  (** A callback is four hooks. The two epoch hooks return [false] to
      request that training stop. *)
  type t = {
    on_epoch_begin : context -> bool;
    on_epoch_end : context -> bool;
    on_train_begin : context -> unit;
    on_train_end : context -> unit;
  }

  (* Resolve a monitor name against the context. "val_loss" and
     "train_loss" read the loss channels directly; any other name is
     looked up among the validation metrics when it starts with "val_",
     otherwise among the training metrics. Shared by the monitoring
     callbacks below. *)
  let monitored_value ~monitor ctx =
    if monitor = "val_loss" then ctx.val_loss
    else if monitor = "train_loss" then ctx.train_loss
    else
      let metrics =
        if String.starts_with ~prefix:"val_" monitor then ctx.val_metrics
        else ctx.train_metrics
      in
      List.find_opt (fun (name, _) -> name = monitor) metrics
      |> Option.map snd

  (** Stop training once [monitor] has gone [patience] consecutive
      epochs without improving by more than [min_delta] ([`Min]: lower
      is better; [`Max]: higher is better). When [baseline] is set, any
      epoch that fails to beat it stops training immediately. Epochs
      where the monitored value is unavailable are ignored. *)
  let early_stopping ?(monitor = "val_loss") ?(patience = 5) ?(mode = `Min)
      ?(min_delta = 0.0) ?(baseline = None) () =
    let best_value = ref None in
    let wait = ref 0 in
    let stopped_epoch = ref 0 in
    let is_better current best =
      match mode with
      | `Min -> current < best -. min_delta
      | `Max -> current > best +. min_delta
    in
    {
      on_epoch_begin = (fun _ -> true);
      on_epoch_end =
        (fun ctx ->
          match monitored_value ~monitor ctx with
          | None -> true
          | Some current ->
              let continue =
                match baseline with
                | Some b when not (is_better current b) ->
                    stopped_epoch := ctx.epoch;
                    false
                | _ -> (
                    match !best_value with
                    | None ->
                        (* First observation becomes the running best. *)
                        best_value := Some current;
                        wait := 0;
                        true
                    | Some best ->
                        if is_better current best then (
                          best_value := Some current;
                          wait := 0;
                          true)
                        else (
                          incr wait;
                          if !wait >= patience then (
                            stopped_epoch := ctx.epoch;
                            Printf.printf "\nEarly stopping at epoch %d\n"
                              ctx.epoch;
                            false)
                          else true))
              in
              continue);
      on_train_begin = (fun _ -> ());
      on_train_end = (fun _ -> ());
    }

  (** Save parameters to [filepath] (any literal ["{epoch}"] substring
      is replaced by the epoch number). With [save_freq = `Best] and
      [save_best_only], a checkpoint is written only when [monitor]
      improves per [mode]; with [`Epoch n], every [n]-th epoch.
      @raise Failure when writing the checkpoint fails. *)
  let model_checkpoint ~filepath ?(monitor = "val_loss") ?(mode = `Min)
      ?(save_best_only = true) ?(save_freq = `Best) () =
    let best_value = ref None in
    let is_better current best =
      match mode with `Min -> current < best | `Max -> current > best
    in
    let save_checkpoint ctx =
      let path =
        Str.global_replace (Str.regexp "{epoch}") (string_of_int ctx.epoch)
          filepath
      in
      match Checkpoint.save_params_file ~path ~params:ctx.state.params with
      | Ok () -> Printf.printf "Saved checkpoint to %s\n" path
      | Error err ->
          failwith
            (Printf.sprintf "Failed to save checkpoint %s: %s" path
               (Checkpoint.error_to_string err))
    in
    {
      on_epoch_begin = (fun _ -> true);
      on_epoch_end =
        (fun ctx ->
          let should_save =
            match save_freq with
            | `Epoch n when ctx.epoch mod n = 0 -> true
            | `Best ->
                if save_best_only then (
                  match monitored_value ~monitor ctx with
                  | None -> false
                  | Some current -> (
                      match !best_value with
                      | None ->
                          best_value := Some current;
                          true
                      | Some best ->
                          if is_better current best then (
                            best_value := Some current;
                            true)
                          else false))
                else true
            | _ -> false
          in
          if should_save then save_checkpoint ctx;
          true);
      on_train_begin = (fun _ -> ());
      on_train_end = (fun _ -> ());
    }

  (** Reduce the learning rate by [factor] once [monitor] plateaus for
      [patience] epochs, waiting [cooldown] epochs after each reduction
      and never going below [min_lr].
      NOTE(review): [current_lr] is never seeded from the optimizer, so
      only the "Would reduce" branch is reachable until this callback is
      wired to a mutable learning rate — confirm before relying on it. *)
  let reduce_lr_on_plateau ?(monitor = "val_loss") ?(factor = 0.1)
      ?(patience = 10) ?(mode = `Min) ?(min_delta = 0.0001) ?(cooldown = 0)
      ?(min_lr = 0.0) () =
    let best_value = ref None in
    let wait = ref 0 in
    let cooldown_counter = ref 0 in
    let current_lr = ref None in
    let is_better current best =
      match mode with
      | `Min -> current < best -. min_delta
      | `Max -> current > best +. min_delta
    in
    {
      on_epoch_begin = (fun _ -> true);
      on_epoch_end =
        (fun ctx ->
          if !cooldown_counter > 0 then (
            (* Still cooling down after a reduction: skip monitoring. *)
            decr cooldown_counter;
            true)
          else
            match monitored_value ~monitor ctx with
            | None -> true
            | Some current -> (
                match !best_value with
                | None ->
                    best_value := Some current;
                    wait := 0;
                    true
                | Some best ->
                    if is_better current best then (
                      best_value := Some current;
                      wait := 0;
                      true)
                    else (
                      incr wait;
                      if !wait >= patience then (
                        let new_lr =
                          match !current_lr with
                          | None ->
                              Printf.printf
                                "\nWould reduce learning rate by factor %.2f \
                                 (min_lr: %.6f)\n"
                                factor min_lr;
                              None
                          | Some lr ->
                              let new_lr_value = lr *. factor in
                              if new_lr_value >= min_lr then (
                                Printf.printf
                                  "\nReducing learning rate from %.6f to %.6f\n"
                                  lr new_lr_value;
                                current_lr := Some new_lr_value;
                                Some new_lr_value)
                              else (
                                Printf.printf
                                  "\nLearning rate %.6f already at minimum \
                                   %.6f\n"
                                  lr min_lr;
                                Some lr)
                        in
                        let _ = new_lr in
                        wait := 0;
                        cooldown_counter := cooldown;
                        true)
                      else true)));
      on_train_begin = (fun _ -> ());
      on_train_end = (fun _ -> ());
    }

  (** Append per-epoch scalars to [log_dir]/metrics.log in a plain
      "key=value" text format (not actual TensorBoard event files).
      [update_freq] is [`Epoch] to log every epoch, or [`Batch n] to log
      only every [n]-th call. *)
  let tensorboard ~log_dir ?(update_freq = `Epoch) () =
    (* Quote the path so directories with spaces or shell metacharacters
       neither break nor inject into the shell command. *)
    let _ =
      Sys.command (Printf.sprintf "mkdir -p %s" (Filename.quote log_dir))
    in
    let batch_counter = ref 0 in
    {
      on_epoch_begin =
        (fun _ ->
          batch_counter := 0;
          true);
      on_epoch_end =
        (fun ctx ->
          let should_log =
            match update_freq with
            | `Epoch -> true
            | `Batch n -> !batch_counter mod n = 0
          in
          if should_log then (
            let log_file = Filename.concat log_dir "metrics.log" in
            let oc = open_out_gen [ Open_append; Open_creat ] 0o644 log_file in
            Printf.fprintf oc "Epoch %d: " ctx.epoch;
            (match ctx.train_loss with
            | Some loss -> Printf.fprintf oc "train_loss=%.4f " loss
            | None -> ());
            List.iter
              (fun (name, value) ->
                Printf.fprintf oc "train_%s=%.4f " name value)
              ctx.train_metrics;
            (match ctx.val_loss with
            | Some loss -> Printf.fprintf oc "val_loss=%.4f " loss
            | None -> ());
            List.iter
              (fun (name, value) -> Printf.fprintf oc "val_%s=%.4f " name value)
              ctx.val_metrics;
            Printf.fprintf oc "\n";
            close_out oc);
          incr batch_counter;
          true);
      on_train_begin = (fun _ -> ());
      on_train_end = (fun _ -> ());
    }

  (** Build a callback from individual hooks; omitted hooks are no-ops
      that let training continue. *)
  let custom ?(on_epoch_begin = fun _ -> true) ?(on_epoch_end = fun _ -> true)
      ?(on_train_begin = fun _ -> ()) ?(on_train_end = fun _ -> ()) () =
    { on_epoch_begin; on_epoch_end; on_train_begin; on_train_end }

  (** Combine callbacks into one. Boolean hooks are ANDed: training
      continues only if every callback agrees. [List.for_all] stops at
      the first [false], so later callbacks may not run that epoch. *)
  let combine callbacks =
    {
      on_epoch_begin =
        (fun ctx -> List.for_all (fun cb -> cb.on_epoch_begin ctx) callbacks);
      on_epoch_end =
        (fun ctx -> List.for_all (fun cb -> cb.on_epoch_end ctx) callbacks);
      on_train_begin =
        (fun ctx -> List.iter (fun cb -> cb.on_train_begin ctx) callbacks);
      on_train_end =
        (fun ctx -> List.iter (fun cb -> cb.on_train_end ctx) callbacks);
    }
end
(* Run inference over [dataset] without touching parameters, returning
   the mean per-batch loss together with the accumulated metric values.
   Raises [Invalid_argument] when the dataset yields no batches. *)
let evaluate ~model ~(state : Train_state.t) ~dataset ~loss_fn
    ?(progress = false) () =
  let state = Train_state.reset_metrics state in
  let loss_sum = ref 0. in
  let batches = ref 0 in
  if progress then Printf.printf "Evaluating: ";
  Dataset.iter
    (fun (x, y) ->
      incr batches;
      loss_sum := !loss_sum +. eval_step ~model ~state ~x ~y ~loss_fn;
      (* One dot every 10 batches keeps the progress line compact. *)
      if progress && !batches mod 10 = 0 then Printf.printf ".")
    dataset;
  if progress then Printf.printf " done\n%!";
  if !batches = 0 then
    invalid_arg
      "Training.evaluate: dataset produced no batches. Ensure your validation \
       dataset yields at least one batch.";
  (!loss_sum /. float_of_int !batches, Train_state.compute_metrics state)
(* [fit] is the top-level training loop: it initializes a [Train_state],
   then for up to [epochs] epochs trains over [train_data], optionally
   evaluates on [val_data], accumulates a [History], and drives the
   combined callbacks. Training stops early when a callback's
   [on_epoch_begin] or [on_epoch_end] hook returns [false]. Returns the
   final state together with the accumulated history. *)
let fit ~model ~optimizer ~loss_fn ?metrics ~train_data ?val_data ~epochs
    ?callbacks ?(progress = true) ~rngs ~dtype () =
  let state = Train_state.init ~model ~optimizer ?metrics ~rngs ~dtype () in
  (* Empty history; the validation channels stay [None] until first used. *)
  let history =
    History.
      {
        train_loss = [];
        train_metrics = [];
        val_loss = None;
        val_metrics = None;
      }
  in
  let state_ref = ref state in
  let history_ref = ref history in
  (* Collapse the callback list into a single combined callback, or
     [None] when no callbacks were supplied. *)
  let callback =
    match callbacks with
    | None -> None
    | Some cbs -> Some (Callbacks.combine cbs)
  in
  (* Notify callbacks before the first epoch; epoch 0 and the empty
     loss/metric fields signal "nothing has run yet". *)
  (match callback with
  | Some cb ->
      let ctx =
        Callbacks.
          {
            epoch = 0;
            state = !state_ref;
            model;
            optimizer;
            history = !history_ref;
            train_loss = None;
            val_loss = None;
            train_metrics = [];
            val_metrics = [];
          }
      in
      cb.on_train_begin ctx
  | None -> ());
  let continue_training = ref true in
  let epoch_idx = ref 1 in
  (* Epochs are 1-based; the loop exits early when a callback vetoes. *)
  while !epoch_idx <= epochs && !continue_training do
    let epoch = !epoch_idx in
    if progress then Printf.printf "\nEpoch %d/%d\n" epoch epochs;
    let epoch_start_time = Unix.gettimeofday () in
    (* Rewind both datasets so every epoch sees the full data. *)
    Dataset.reset train_data;
    (match val_data with Some ds -> Dataset.reset ds | None -> ());
    (* [on_epoch_begin] may veto this epoch (and thereby end training);
       this epoch's losses/metrics are not yet known, hence empty. *)
    (match callback with
    | Some cb ->
        let ctx =
          Callbacks.
            {
              epoch;
              state = !state_ref;
              model;
              optimizer;
              history = !history_ref;
              train_loss = None;
              val_loss = None;
              train_metrics = [];
              val_metrics = [];
            }
        in
        continue_training := cb.on_epoch_begin ctx
    | None -> ());
    if !continue_training then (
      let state', train_loss, train_metrics =
        train_epoch ~model ~optimizer ~state:!state_ref ~dataset:train_data
          ~loss_fn ~progress ()
      in
      state_ref := state';
      if progress then Printf.printf " Train loss: %.4f" train_loss;
      List.iter
        (fun (name, value) ->
          if progress then Printf.printf ", %s: %.4f" name value)
        train_metrics;
      if progress then Printf.printf "\n";
      (* Append this epoch's training results to the history. *)
      history_ref :=
        {
          !history_ref with
          train_loss = !history_ref.train_loss @ [ train_loss ];
          train_metrics =
            merge_metric_history !history_ref.train_metrics train_metrics;
        };
      (* Optional validation pass, recorded into the history the same
         way (lists are created on first use). *)
      (match val_data with
      | Some val_dataset ->
          let val_loss, val_metrics =
            evaluate ~model ~state:!state_ref ~dataset:val_dataset ~loss_fn
              ~progress ()
          in
          if progress then Printf.printf " Val loss: %.4f" val_loss;
          List.iter
            (fun (name, value) ->
              if progress then Printf.printf ", %s: %.4f" name value)
            val_metrics;
          if progress then Printf.printf "\n%!";
          let val_loss_list =
            match !history_ref.val_loss with
            | Some l -> l @ [ val_loss ]
            | None -> [ val_loss ]
          in
          let val_metrics_list =
            update_optional_metric_history !history_ref.val_metrics val_metrics
          in
          history_ref :=
            {
              !history_ref with
              val_loss = Some val_loss_list;
              val_metrics = val_metrics_list;
            }
      | None -> ());
      let epoch_time = Unix.gettimeofday () -. epoch_start_time in
      if progress then Printf.printf " Time: %.2fs\n%!" epoch_time;
      (* [on_epoch_end] sees this epoch's final numbers; returning
         [false] stops training without advancing the epoch counter. *)
      match callback with
      | Some cb ->
          let train_metrics_final = train_metrics in
          (* Latest validation metrics are read back from the history
             (last value of each series). *)
          let val_metrics_final =
            match val_data with
            | None -> []
            | Some _ -> (
                match !history_ref.val_metrics with
                | Some metrics ->
                    List.map
                      (fun (name, values) ->
                        match List.rev values with
                        | [] -> (name, 0.0)
                        | hd :: _ -> (name, hd))
                      metrics
                | None -> [])
          in
          let val_loss_final =
            match val_data with
            | None -> None
            | Some _ -> History.final_val_loss !history_ref
          in
          let ctx =
            Callbacks.
              {
                epoch;
                state = !state_ref;
                model;
                optimizer;
                history = !history_ref;
                train_loss = Some train_loss;
                val_loss = val_loss_final;
                train_metrics = train_metrics_final;
                val_metrics = val_metrics_final;
              }
          in
          if cb.on_epoch_end ctx then incr epoch_idx
          else continue_training := false
      | None -> incr epoch_idx)
    else incr epoch_idx
  done;
  (* Final notification with the last recorded losses and metrics. *)
  (match callback with
  | Some cb ->
      let ctx =
        Callbacks.
          {
            epoch = epochs;
            state = !state_ref;
            model;
            optimizer;
            history = !history_ref;
            train_loss = History.final_train_loss !history_ref;
            val_loss = History.final_val_loss !history_ref;
            train_metrics = History.final_train_metrics !history_ref;
            val_metrics = History.final_val_metrics !history_ref;
          }
      in
      cb.on_train_end ctx
  | None -> ());
  (!state_ref, !history_ref)