package torch

  1. Overview
  2. Docs
Legend:
Page
Library
Module
Module type
Parameter
Class
Class type
Source

Source file checkpointing.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
open Base

(** [latest_index_and_filename ~checkpoint_base] scans the directory containing
    [checkpoint_base] for files named [<basename>.<n>], where [<n>] is a
    non-empty string of decimal digits. It returns the highest [n] together
    with the path of that file (the directory joined with the file name), or
    [None] when no such file exists. Files such as [<basename>.tmp] or
    [<basename>.final] are ignored.

    When two files parse to the same index (e.g. [ckpt.5] and [ckpt.05]), the
    one whose path sorts last is returned.

    @raise Sys_error if the directory cannot be read. *)
let latest_index_and_filename ~checkpoint_base =
  let dirname = Stdlib.Filename.dirname checkpoint_base in
  let prefix = Stdlib.Filename.basename checkpoint_base ^ "." in
  (* Only plain decimal suffixes are checkpoint indices (that is what the
     saver writes). Base's [Int.of_string] alone would also accept forms such
     as ["0x10"], ["-3"] or ["1_000"]. *)
  let parse_index suffix =
    if String.is_empty suffix || not (String.for_all suffix ~f:Char.is_digit)
    then None
    else (
      match Int.of_string suffix with
      | index -> Some index
      (* Only reachable when the digit string overflows [int]. *)
      | exception Failure _ -> None)
  in
  Stdlib.Sys.readdir dirname
  |> Array.to_list
  |> List.filter_map ~f:(fun filename ->
    Option.map
      (Option.bind (String.chop_prefix filename ~prefix) ~f:parse_index)
      ~f:(fun index -> index, Stdlib.Filename.concat dirname filename))
  |> List.max_elt ~compare:(fun (index1, file1) (index2, file2) ->
    match Int.compare index1 index2 with
    | 0 -> String.compare file1 file2
    | order -> order)
;;

(** [loop ~start_index ~end_index ~var_stores ~checkpoint_base ?only_keep
    ?checkpoint_every f] calls [f ~index] for every [index] from [start_index]
    to [end_index] (inclusive). Along the way it saves the variables of
    [var_stores] to numbered checkpoint files
    [checkpoint_base ^ "." ^ Int.to_string index].

    If a numbered checkpoint already exists next to [checkpoint_base], the
    variables are first restored from the one with the highest index. The loop
    then resumes right after that index, and [start_index] is ignored.

    - [var_stores]: with a single store, its variable names are used as-is.
      With several stores, each name is prefixed by ["<store name>:"].
    - [only_keep]: if given, at most this many numbered checkpoints are kept on
      disk. Older ones, including those left by a previous run, are deleted
      after a newer checkpoint has been written.
    - [checkpoint_every]: [`seconds s] saves when more than [s] seconds have
      elapsed since the last save (default [`seconds 600.]). [`iters n] saves
      whenever [index] is a multiple of [n].

    Each checkpoint is first written to [checkpoint_base ^ ".tmp"] and then
    renamed into place. After the last iteration, a checkpoint named
    [checkpoint_base ^ ".final"] is always written.

    @raise Invalid_argument if [start_index] is negative, [only_keep] is not
    positive, or [checkpoint_every] is [`iters n] with [n <= 0]. *)
let loop
      ~start_index
      ~end_index
      ~var_stores
      ~checkpoint_base
      ?only_keep
      ?(checkpoint_every = `seconds 600.)
      f
  =
  if start_index < 0 then Printf.invalid_argf "negative start_index %d" start_index ();
  Option.iter only_keep ~f:(fun only_keep ->
    if only_keep <= 0 then Printf.invalid_argf "non-positive only_keep %d" only_keep ());
  (* Validate up front: Base's [%] raises on a non-positive modulus, which
     would otherwise only surface after the first call to [f]. *)
  (match checkpoint_every with
   | `iters iters ->
     if iters <= 0
     then Printf.invalid_argf "non-positive checkpoint_every iters %d" iters ()
   | `seconds _ -> ());
  let temp_checkpoint = checkpoint_base ^ ".tmp" in
  let latest_index_and_filename = latest_index_and_filename ~checkpoint_base in
  let named_tensors =
    match var_stores with
    | [ vs ] -> Var_store.all_vars vs
    | var_stores ->
      List.concat_map var_stores ~f:(fun vs ->
        let vs_name = Var_store.name vs in
        Var_store.all_vars vs
        |> List.map ~f:(fun (name, tensor) -> Printf.sprintf "%s:%s" vs_name name, tensor))
  in
  Option.iter latest_index_and_filename ~f:(fun (latest_index, filename) ->
    Stdio.eprintf "Restoring checkpoint for index %d from '%s'.\n%!" latest_index filename;
    Serialize.load_multi_ ~named_tensors ~filename);
  let start_index =
    Option.value_map latest_index_and_filename ~default:start_index ~f:(fun (index, _) ->
      index + 1)
  in
  (* Paths of the numbered checkpoints already on disk, oldest first. Only
     plain decimal suffixes are considered, i.e. names this function itself
     writes, because these files may end up being deleted. *)
  let existing_checkpoints () =
    let dirname = Stdlib.Filename.dirname checkpoint_base in
    let prefix = Stdlib.Filename.basename checkpoint_base ^ "." in
    Stdlib.Sys.readdir dirname
    |> Array.to_list
    |> List.filter_map ~f:(fun filename ->
      match String.chop_prefix filename ~prefix with
      | Some suffix
        when (not (String.is_empty suffix)) && String.for_all suffix ~f:Char.is_digit ->
        (match Int.of_string suffix with
         | index -> Some (index, Stdlib.Filename.concat dirname filename)
         | exception Failure _ -> None)
      | Some _ | None -> None)
    |> List.sort ~compare:(fun (index1, _) (index2, _) -> Int.compare index1 index2)
    |> List.map ~f:(fun (_, filename) -> filename)
  in
  (* The retention queue holds checkpoint paths. It is seeded with the
     checkpoints of a previous run so that [only_keep] still applies after a
     resume. Pruning is deferred until a new checkpoint has been written. *)
  let only_keep =
    Option.map only_keep ~f:(fun only_keep ->
      let queue = Linked_queue.create () in
      List.iter (existing_checkpoints ()) ~f:(Linked_queue.enqueue queue);
      only_keep, queue)
  in
  let remove_checkpoint filename =
    (* A checkpoint removed by someone else must not abort the run. *)
    try Unix.unlink filename with
    | Unix.Unix_error (Unix.ENOENT, _, _) -> ()
  in
  let save ~suffix =
    Serialize.save_multi ~named_tensors ~filename:temp_checkpoint;
    Unix.rename temp_checkpoint (Printf.sprintf "%s.%s" checkpoint_base suffix)
  in
  let save_index index =
    save ~suffix:(Int.to_string index);
    Option.iter only_keep ~f:(fun (only_keep, queue) ->
      Linked_queue.enqueue
        queue
        (Printf.sprintf "%s.%s" checkpoint_base (Int.to_string index));
      (* [while] rather than [if]: a seeded queue may exceed the limit by more
         than one entry. *)
      while Linked_queue.length queue > only_keep do
        remove_checkpoint (Linked_queue.dequeue_exn queue)
      done)
  in
  let last_checkpoint_time = ref (Unix.time ()) in
  for index = start_index to end_index do
    f ~index;
    let should_checkpoint =
      match checkpoint_every with
      | `seconds seconds -> Float.( > ) (Unix.time () -. !last_checkpoint_time) seconds
      | `iters iters -> index % iters = 0
    in
    if should_checkpoint
    then (
      save_index index;
      last_checkpoint_time := Unix.time ())
  done;
  save ~suffix:"final"
;;