package fehu

  1. Overview
  2. Docs
Reinforcement learning framework for OCaml

Install

dune-project
 Dependency

Authors

Maintainers

Sources

raven-1.0.0.alpha2.tbz
sha256=93abc49d075a1754442ccf495645bc4fdc83e4c66391ec8aca8fa15d2b4f44d2
sha512=5eb958c51f30ae46abded4c96f48d1825f79c7ce03f975f9a6237cdfed0d62c0b4a0774296694def391573d849d1f869919c49008acffca95946b818ad325f6f

doc/src/fehu.visualize/wrapper_video.ml.html

Source file wrapper_video.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
open Fehu

(** Policy selecting which parts of a run are written to video.

    - [Every_n_episodes n]: an episode is recorded when its 1-based number is
      a multiple of [n]; [should_record_episode] rejects [n <= 0].
    - [Steps predicate]: [predicate] is applied to the recorder's running
      step counter; a clip is open for as long as it returns [true]. *)
type when_to_record = Every_n_episodes of int | Steps of (int -> bool)

(** [derive_id env suffix] is the identifier of [env] with [suffix] appended,
    or [None] when [Env.id env] is [None]. *)
let derive_id env suffix = Option.map (fun id -> id ^ suffix) (Env.id env)

(** [mkdir_p path] creates [path] together with any missing ancestors, in the
    manner of [mkdir -p]. New directories get mode [0o755].

    Nothing is done for [""], ["."], [Filename.dir_sep], or a path for which
    [Sys.file_exists] is already true. NOTE(review): that last test is also
    true for a regular file, in which case no error is reported here.

    @raise Unix.Unix_error
      if [Unix.mkdir] fails for any reason other than [EEXIST]. *)
let mkdir_p path =
  let rec aux dir =
    if dir = "" || dir = "." || dir = Filename.dir_sep then ()
    else if Sys.file_exists dir then ()
    else (
      let parent = Filename.dirname dir in
      (* [Filename.dirname] maps a filesystem root to itself. Only recurse
         while it makes progress; otherwise a missing root that is not
         [Filename.dir_sep] would recurse without end instead of letting
         [Unix.mkdir] report the failure. *)
      if parent <> dir then aux parent;
      (* EEXIST is tolerated: another process may have created the directory
         between the existence test and this call. *)
      try Unix.mkdir dir 0o755 with Unix.Unix_error (Unix.EEXIST, _, _) -> ())
  in
  aux path

(** Frame rate for recordings: the [render_fps] stored in [metadata] when it
    is present, and 30 otherwise. *)
let fps_of_metadata metadata =
  match metadata.Metadata.render_fps with Some fps -> fps | None -> 30

(** Resolves an optional overlay, falling back to [Overlay.identity] when none
    was supplied. *)
let default_overlay overlay =
  match overlay with Some custom -> custom | None -> Overlay.identity

(** [expect_image_frame source frame] extracts the image carried by a
    [Render.Image] frame.

    @raise Invalid_argument
      for [Render.None], [Render.Text] and [Render.Svg] frames; the message is
      prefixed with [source]. *)
let expect_image_frame source frame =
  (* Formats the message and raises [Invalid_argument] with it. *)
  let reject fmt = Printf.ksprintf invalid_arg fmt in
  match frame with
  | Render.Image image -> image
  | Render.Svg _ ->
      reject "%s: render produced SVG; expected raster image" source
  | Render.Text _ -> reject "%s: render produced text; expected image" source
  | Render.None ->
      reject
        "%s: render returned None; choose a render_mode that returns frames"
        source

(** [episode_file path episode_idx] is the MP4 file inside [path] for the
    episode with 0-based index [episode_idx]. The file name carries the
    1-based episode number, zero-padded to six digits. *)
let episode_file path episode_idx =
  let name = Printf.sprintf "episode_%06d.mp4" (succ episode_idx) in
  Filename.concat path name

(** [step_file path step_idx] is the MP4 file inside [path] for a clip that
    starts at step [step_idx]. The index is used as given and zero-padded to
    eight digits. *)
let step_file path step_idx =
  step_idx |> Printf.sprintf "step_%08d.mp4" |> Filename.concat path

(** [should_record_episode idx policy] tells whether the episode with 0-based
    index [idx] is recorded under [policy]. With [Every_n_episodes n] that is
    when [idx + 1] is a multiple of [n]; with [Steps _] it is never the case.

    @raise Invalid_argument if [policy] is [Every_n_episodes n] with [n <= 0]. *)
let should_record_episode idx policy =
  match policy with
  | Steps _ -> false
  | Every_n_episodes period ->
      if period > 0 then (idx + 1) mod period = 0
      else invalid_arg "record_video: n must be positive"

(** Bookkeeping for the single-environment recorder. [record_video] keeps one
    value of this type in a [ref] and replaces it on every change. *)
type single_state = {
  episode_idx : int;
      (** 0-based index of the current episode; starts at [-1] and is
          incremented on every reset. *)
  step_in_episode : int;
      (** Steps taken in the current episode; set back to [0] on reset and
          when a step ends the episode. *)
  global_step : int;
      (** Steps taken through the wrapper since it was created; this is the
          value handed to a [Steps] predicate. *)
  sink : Sink.t option;
      (** The open video sink, or [None] when nothing is being written. *)
  recording : bool;
      (** Set to [true] when a sink is opened and to [false] when it is
          closed. *)
}

(** [record_video ~when_to_record ~path ?overlay env] builds, with
    [Env.create], an environment that forwards reset, step, render and close
    to [env] and writes rendered frames to MP4 files under [path] through
    [Sink.ffmpeg]. [path] is created first with [mkdir_p].

    - With [Every_n_episodes n], a file named by [episode_file] is opened
      after the reset of every selected episode. The frame rendered right
      after the reset is pushed when there is one, followed by one frame per
      step.
    - With [Steps predicate], [predicate] is applied on every step to the
      number of steps taken before it. A file named by [step_file] is opened
      when it turns [true] and closed when it turns [false].

    The open file is closed on reset, when a step terminates or truncates the
    episode, and when the wrapper is closed. Every frame goes through
    [overlay] (default: [Overlay.identity]) before being pushed.

    @raise Invalid_argument
      on reset if [when_to_record] is [Every_n_episodes n] with [n <= 0]; on
      step if [Env.render] returns [None] while a frame is due; whenever a
      rendered frame is not a [Render.Image]. *)
let record_video ~when_to_record ~path ?overlay env =
  mkdir_p path;
  let overlay = default_overlay overlay in
  let metadata = Env.metadata env in
  let fps = fps_of_metadata metadata in
  let action_space = Env.action_space env in
  let render_mode = Env.render_mode env in
  let tracker : single_state ref =
    ref
      {
        episode_idx = -1;
        step_in_episode = 0;
        global_step = 0;
        sink = None;
        recording = false;
      }
  in
  (* Closes the open sink, if any, and marks the recorder idle. *)
  let stop_recording () =
    Option.iter
      (fun active ->
        Sink.close active;
        tracker := { !tracker with sink = None; recording = false })
      !tracker.sink
  in
  (* Opens an ffmpeg sink writing to [file] and marks the recorder active. *)
  let start_recording file =
    let opened = Sink.ffmpeg ~fps ~path:file () in
    tracker := { !tracker with sink = Some opened; recording = true }
  in
  (* Applies the overlay to [raw] and pushes the result to the open sink;
     does nothing when no sink is open. *)
  let emit ~info ~action ~reward ~finished raw =
    match !tracker.sink with
    | None -> ()
    | Some active ->
        let { step_in_episode; episode_idx; _ } = !tracker in
        let ctx =
          {
            Overlay.step_idx = step_in_episode;
            episode_idx;
            info;
            action;
            value = None;
            log_prob = None;
            reward;
            done_ = finished;
          }
        in
        let decorated = overlay raw ctx in
        Sink.push active (Render.Image decorated)
  in
  let reset_handler _ ?options () =
    stop_recording ();
    tracker :=
      {
        !tracker with
        episode_idx = !tracker.episode_idx + 1;
        step_in_episode = 0;
      };
    let ((_, info) as outcome) = Env.reset env ?options () in
    let episode = !tracker.episode_idx in
    if should_record_episode episode when_to_record then begin
      start_recording (episode_file path episode);
      (* A missing frame right after reset is tolerated. *)
      match Env.render env with
      | None -> ()
      | Some frame ->
          let raw =
            expect_image_frame "Wrapper_video.record_video(reset)" frame
          in
          emit ~info ~action:None ~reward:0. ~finished:false raw
    end;
    outcome
  in
  let step_handler _ action =
    let step_index = !tracker.global_step in
    let transition = Env.step env action in
    let finished = transition.terminated || transition.truncated in
    let packed = Space.pack action_space action in
    let capture =
      match when_to_record with
      | Every_n_episodes _ -> !tracker.recording
      | Steps wants ->
          let wanted = wants step_index in
          (match (wanted, !tracker.recording) with
          | true, false ->
              stop_recording ();
              start_recording (step_file path step_index)
          | false, true -> stop_recording ()
          | true, true | false, false -> ());
          wanted
    in
    if capture then begin
      match Env.render env with
      | Some frame ->
          let raw = expect_image_frame "Wrapper_video.record_video" frame in
          emit ~info:transition.info ~action:(Some packed)
            ~reward:transition.reward ~finished raw
      | None ->
          invalid_arg
            "record_video: Env.render returned None while recording. Ensure \
             render_mode returns frames"
    end;
    let next_in_episode =
      if finished then 0 else !tracker.step_in_episode + 1
    in
    tracker :=
      {
        !tracker with
        global_step = step_index + 1;
        step_in_episode = next_in_episode;
      };
    if finished then stop_recording ();
    transition
  in
  Env.create
    ?id:(derive_id env "/VideoRecorder")
    ?render_mode ~metadata ~rng:(Env.rng env)
    ~observation_space:(Env.observation_space env)
    ~action_space ~reset:reset_handler ~step:step_handler
    ~render:(fun _ -> Env.render env)
    ~close:(fun _ ->
      stop_recording ();
      Env.close env)
    ()

(* The vectorised recorder reuses the single-env wrapper for Single_each layout
   and coordinates frames across environments for NxM grids. *)

(** State shared by all environments wrapped for a grid recording. One value
    is created by [vec_record_video] and handed to every
    [wrap_env_for_grid] wrapper. *)
type shared = {
  layout_rows : int;  (** Row count passed to [Utils.compose_grid]. *)
  layout_cols : int;  (** Column count passed to [Utils.compose_grid]. *)
  when_to_record : when_to_record;  (** Recording policy. *)
  overlay : Overlay.t;  (** Applied to every frame before it is stored. *)
  base_path : string;  (** Directory receiving the video files. *)
  fps : int;  (** Frame rate given to [Sink.ffmpeg]. *)
  num_envs : int;  (** Number of wrapped environments. *)
  frames : Render.image option array;
      (** Latest frame of each environment for the batch being collected;
          reset to [None] by [clear_frames]. *)
  step_in_episode : int array;
      (** Per-environment step counter, reset to [0] on reset and at episode
          end. *)
  episode_counts : int array;
      (** Per-environment episode index; starts at [-1] and is incremented on
          each reset of that environment. *)
  mutable sink : Sink.t option;
      (** The open video sink, or [None] when nothing is being written. *)
  mutable recording : bool;
      (** Set when a sink is opened, cleared when it is closed. *)
  mutable frames_recorded : int;
      (** Number of frames stored since the last [clear_frames]. *)
  mutable global_step : int;
      (** Number of completed batches; incremented by [flush_frames] and used
          for the [Steps] predicate and step file names. *)
  mutable episode_idx : int;
      (** Shared episode index; starts at [-1] and is incremented whenever
          environment 0 is reset. *)
}

(** Closes the sink held by [shared], if there is one, then clears the [sink]
    and [recording] fields. Does nothing when no sink is open. *)
let close_shared_sink shared =
  Option.iter
    (fun active ->
      Sink.close active;
      shared.sink <- None;
      shared.recording <- false)
    shared.sink

(** Opens an ffmpeg sink for the current shared episode and marks [shared] as
    recording. The file is named by [episode_file], so grid recordings follow
    the same naming scheme as the single-environment recorder. Any sink
    already stored in [shared] is overwritten without being closed; callers
    close it first. *)
let open_shared_episode_sink shared =
  let file = episode_file shared.base_path shared.episode_idx in
  shared.sink <- Some (Sink.ffmpeg ~fps:shared.fps ~path:file ());
  shared.recording <- true

(** Opens an ffmpeg sink for a clip starting at the current
    [shared.global_step] and marks [shared] as recording. The file is named
    by [step_file], so grid recordings follow the same naming scheme as the
    single-environment recorder. Any sink already stored in [shared] is
    overwritten without being closed; callers close it first. *)
let open_shared_step_sink shared =
  let file = step_file shared.base_path shared.global_step in
  shared.sink <- Some (Sink.ffmpeg ~fps:shared.fps ~path:file ());
  shared.recording <- true

(** Empties the first [num_envs] frame slots of [shared] and sets the
    collected-frame counter back to zero. *)
let clear_frames shared =
  let { frames; num_envs; _ } = shared in
  Array.fill frames 0 num_envs None;
  shared.frames_recorded <- 0

(** [flush_frames shared] completes a batch once [frames_recorded] equals
    [num_envs]; before that it does nothing.

    For a complete batch it:
    - under a [Steps] policy, evaluates the predicate at [global_step] and
      opens or closes the step sink to match;
    - if recording, composes the stored frames with [Utils.compose_grid] and
      pushes the result to the sink;
    - increments [global_step] and clears the frame slots.

    The push is the same for both policies, so it is written once.

    @raise Failure if recording and some frame slot is still [None]. *)
let flush_frames shared =
  if shared.frames_recorded = shared.num_envs then (
    (match shared.when_to_record with
    | Steps predicate ->
        let capture = predicate shared.global_step in
        if capture && not shared.recording then (
          close_shared_sink shared;
          open_shared_step_sink shared)
        else if (not capture) && shared.recording then close_shared_sink shared
    | Every_n_episodes _ -> ());
    (if shared.recording then
       let images =
         Array.map
           (function
             | Some image -> image
             | None -> failwith "wrapper_video: missing frame")
           shared.frames
       in
       let composed =
         Utils.compose_grid ~rows:shared.layout_rows ~cols:shared.layout_cols
           images
       in
       match shared.sink with
       | Some sink -> Sink.push sink (Render.Image composed)
       | None -> ());
    shared.global_step <- shared.global_step + 1;
    clear_frames shared)

(** [wrap_env_for_grid shared idx env] wraps [env] as member [idx] of a grid
    recording coordinated through [shared].

    - On reset, the per-environment counters for [idx] are updated.
      Environment 0 additionally advances the shared episode index, closes
      the current sink and, if [should_record_episode] selects the new
      episode, opens the next episode file.
    - On step, the environment is rendered, the overlay is applied, and the
      frame is stored in slot [idx]; [flush_frames] then emits the grid once
      every slot is filled. Rendering happens on every step, whether or not
      a sink is currently open.
    - On close, environment 0 closes the shared sink.

    @raise Invalid_argument
      on step if [Env.render] returns [None] or a frame that is not a
      [Render.Image]. *)
let wrap_env_for_grid shared idx env =
  let action_space = Env.action_space env in
  let render_mode = Env.render_mode env in
  let metadata = Env.metadata env in
  let rng = Env.rng env in
  let reset_handler _ ?options () =
    shared.step_in_episode.(idx) <- 0;
    shared.episode_counts.(idx) <- shared.episode_counts.(idx) + 1;
    (* Only environment 0 drives the shared episode index and the sink. *)
    if idx = 0 then (
      shared.episode_idx <- shared.episode_idx + 1;
      close_shared_sink shared;
      if should_record_episode shared.episode_idx shared.when_to_record then
        open_shared_episode_sink shared);
    Env.reset env ?options ()
  in
  let step_handler _ action =
    let transition = Env.step env action in
    let done_flag = transition.terminated || transition.truncated in
    (match Env.render env with
    | None ->
        invalid_arg "vec_record_video: Env.render returned None while recording"
    | Some frame ->
        let image = expect_image_frame "Wrapper_video.vec_record_video" frame in
        let ctx =
          {
            Overlay.step_idx = shared.step_in_episode.(idx);
            episode_idx = shared.episode_counts.(idx);
            info = transition.info;
            action = Some (Space.pack action_space action);
            value = None;
            log_prob = None;
            reward = transition.reward;
            done_ = done_flag;
          }
        in
        let image = shared.overlay image ctx in
        (* Count a slot only the first time it is filled in the current
           batch, so [frames_recorded] always equals the number of filled
           slots. If this environment steps again before the others have
           reported, its frame is replaced instead of making the batch look
           complete while another slot is still empty. *)
        (match shared.frames.(idx) with
        | None -> shared.frames_recorded <- shared.frames_recorded + 1
        | Some _ -> ());
        shared.frames.(idx) <- Some image;
        flush_frames shared);
    shared.step_in_episode.(idx) <-
      (if done_flag then 0 else shared.step_in_episode.(idx) + 1);
    transition
  in
  Env.create
    ?id:(derive_id env "/VideoRecorder")
    ?render_mode ~metadata ~rng
    ~observation_space:(Env.observation_space env)
    ~action_space ~reset:reset_handler ~step:step_handler
    ~render:(fun _ -> Env.render env)
    ~close:(fun _ ->
      if idx = 0 then close_shared_sink shared;
      Env.close env)
    ()

(** [vec_record_video ~layout ~when_to_record ~path ?overlay vec_env] adds
    video recording to every environment of [vec_env] and returns [vec_env].
    [path] is created first with [mkdir_p].

    - [`Single_each]: each environment is wrapped with [record_video] and
      writes into its own sub-directory [env_NN] of [path], numbered from 1.
    - [`NxM_grid (rows, cols)]: each environment is wrapped with
      [wrap_env_for_grid] and all of them share one recording in which each
      video frame is a [rows] x [cols] grid of the environments' frames.

    The wrappers are stored into the array returned by [Vector_env.envs].
    NOTE(review): this assumes that array is the vector environment's own
    storage rather than a copy; confirm against [Vector_env].

    @raise Invalid_argument
      for a grid layout with a negative dimension, or whose [rows * cols]
      differs from the number of environments. *)
let vec_record_video ~layout ~when_to_record ~path ?overlay vec_env =
  mkdir_p path;
  let overlay = default_overlay overlay in
  match layout with
  | `Single_each ->
      let envs = Vector_env.envs vec_env in
      Array.iteri
        (fun idx env ->
          let subdir =
            Filename.concat path (Printf.sprintf "env_%02d" (idx + 1))
          in
          envs.(idx) <- record_video ~when_to_record ~path:subdir ~overlay env)
        envs;
      vec_env
  | `NxM_grid (rows, cols) ->
      let num_envs = Vector_env.num_envs vec_env in
      (* The product test alone accepts two negative dimensions, e.g.
         (-2, -2) for four environments. *)
      if rows < 0 || cols < 0 then
        invalid_arg "vec_record_video: grid dimensions must not be negative";
      if rows * cols <> num_envs then
        invalid_arg "vec_record_video: grid layout must cover all environments";
      let metadata = Vector_env.metadata vec_env in
      let fps = fps_of_metadata metadata in
      let shared =
        {
          layout_rows = rows;
          layout_cols = cols;
          when_to_record;
          overlay;
          base_path = path;
          fps;
          num_envs;
          frames = Array.make num_envs None;
          step_in_episode = Array.make num_envs 0;
          episode_counts = Array.make num_envs (-1);
          sink = None;
          recording = false;
          frames_recorded = 0;
          global_step = 0;
          episode_idx = -1;
        }
      in
      let envs = Vector_env.envs vec_env in
      Array.iteri
        (fun idx env -> envs.(idx) <- wrap_env_for_grid shared idx env)
        envs;
      vec_env