package fehu

  1. Overview
  2. Docs
Reinforcement learning framework for OCaml

Install

dune-project
 Dependency

Authors

Maintainers

Sources

raven-1.0.0.alpha2.tbz
sha256=93abc49d075a1754442ccf495645bc4fdc83e4c66391ec8aca8fa15d2b4f44d2
sha512=5eb958c51f30ae46abded4c96f48d1825f79c7ce03f975f9a6237cdfed0d62c0b4a0774296694def391573d849d1f869919c49008acffca95946b818ad325f6f

doc/src/fehu.visualize/fehu_visualize.ml.html

Source file fehu_visualize.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
open Fehu
(** Re-export of [Overlay]. In this file overlays are applied as
    [overlay image ctx], and [Overlay.identity] is the default overlay. *)
module Overlay = Overlay

(** Re-export of [Wrapper_video] under the shorter name [Video]. *)
module Video = Wrapper_video

(** Re-export of [Sink], the frame consumer that the recording functions in
    this file push frames to and close. *)
module Sink = Sink

(** Alias of [Sink.push]. Kept as a direct alias rather than an eta-expanded
    wrapper so that its exact signature is preserved. *)
let push = Sink.push

(** [push_many sink frames] pushes each element of [frames] to [sink], one at
    a time, in ascending index order. *)
let push_many sink frames =
  for i = 0 to Array.length frames - 1 do
    Sink.push sink frames.(i)
  done

(** [expect_image_frame source frame] returns the raster image carried by
    [frame].

    @raise Invalid_argument with a message prefixed by ["<source>: "] when
    [frame] is [Render.None], [Render.Text _] or [Render.Svg _]. *)
let expect_image_frame source frame =
  (* Every rejection shares the "<source>: <reason>" message shape. *)
  let reject reason = invalid_arg (source ^ ": " ^ reason) in
  match frame with
  | Render.Image image -> image
  | Render.None ->
      reject "render returned None; ensure render_mode returns frames"
  | Render.Text _ -> reject "render produced ANSI text; expected image"
  | Render.Svg _ -> reject "render produced SVG; expected raster image"

(** [default_overlay overlay] returns the overlay the caller supplied, or
    [Overlay.identity] when none was given. *)
let default_overlay = function
  | Some overlay -> overlay
  | None -> Overlay.identity

(** [record_rollout ~env ~policy ~steps ?overlay ~sink ()] drives [env] with
    [policy] for exactly [steps] environment steps and pushes rendered frames
    to [sink].

    [policy] maps an observation to [(action, log_prob, value)]; [log_prob]
    and [value] are forwarded unchanged to the overlay context.

    After every step, [Env.render env] must return an image frame. The frame
    is passed through [overlay] (default [Overlay.identity]) with the step's
    context and then pushed. When the step terminates or truncates an
    episode, [env] is reset. Unless that was the final step, the post-reset
    frame is also overlaid, with an empty context, and pushed. If
    [Env.render] returns [None] right after such a reset, the frame is
    skipped silently. Note that the frame produced by the very first reset is
    not recorded.

    Apart from the [steps <= 0] check, which fails before [sink] is touched,
    [sink] is closed on every exit path, including when an exception escapes.

    @raise Invalid_argument if [steps <= 0], if [Env.render] returns [None]
    after a step, or if a rendered frame is not a raster image. *)
let record_rollout ~env ~policy ~steps ?overlay ~sink () =
  if steps <= 0 then invalid_arg "record_rollout: steps must be positive";
  let overlay = default_overlay overlay in
  let finalize () = Sink.close sink in
  (* Everything that touches [env] runs inside [Fun.protect], so [sink] is
     closed even when the initial reset fails. *)
  Fun.protect ~finally:finalize (fun () ->
      let action_space = Env.action_space env in
      let observation, _info = Env.reset env () in
      let current_observation = ref observation in
      let episode_idx = ref 0 in
      let step_idx = ref 0 in
      for i = 0 to steps - 1 do
        let action, log_prob, value = policy !current_observation in
        let transition = Env.step env action in
        let frame =
          match Env.render env with
          | Some frame -> frame
          | None ->
              invalid_arg
                "record_rollout: Env.render returned None; choose `Rgb_array` \
                 render mode"
        in
        let image = expect_image_frame "record_rollout" frame in
        let ctx =
          {
            Overlay.step_idx = !step_idx;
            episode_idx = !episode_idx;
            info = transition.info;
            action = Some (Space.pack action_space action);
            value;
            log_prob;
            reward = transition.reward;
            done_ = transition.terminated || transition.truncated;
          }
        in
        let image = overlay image ctx in
        Sink.push sink (Render.Image image);
        incr step_idx;
        if ctx.done_ then (
          incr episode_idx;
          let obs, _info_reset = Env.reset env () in
          current_observation := obs;
          (* Record the first frame of the next episode, unless no further
             step will be taken. *)
          if i < steps - 1 then
            match Env.render env with
            | None -> ()
            | Some frame ->
                let image = expect_image_frame "record_rollout(reset)" frame in
                let reset_ctx =
                  {
                    Overlay.step_idx = !step_idx;
                    episode_idx = !episode_idx;
                    info = Info.empty;
                    action = None;
                    value = None;
                    log_prob = None;
                    reward = 0.;
                    done_ = false;
                  }
                in
                let image = overlay image reset_ctx in
                Sink.push sink (Render.Image image))
        else current_observation := transition.observation
      done)

(** [take_first n list] returns the first [n] elements of [list], in order.
    If [list] has fewer than [n] elements, the whole list is returned. If
    [n <= 0], the result is [[]]. Tail-recursive. *)
let take_first n list =
  let rec aux count acc = function
    (* [count < n] instead of [count = n]: a negative [n] must stop at once
       rather than run through the whole list. *)
    | x :: xs when count < n -> aux (count + 1) (x :: acc) xs
    | _ -> List.rev acc
  in
  aux 0 [] list

(** [mean_and_std floats] returns [(mean, std)] of [floats], where [std] is
    the population standard deviation (squared deviations divided by the
    element count). An empty list yields [(0., 0.)]. A single-element list
    always has [std = 0.]. *)
let mean_and_std floats =
  let count = List.length floats in
  if count = 0 then (0., 0.)
  else
    let n = float_of_int count in
    let mean = List.fold_left ( +. ) 0. floats /. n in
    let add_squared_dev acc x =
      let d = x -. mean in
      acc +. (d *. d)
    in
    (* The single-element case is pinned to 0. explicitly; computing it would
       give nan for an infinite element. *)
    let variance =
      if count = 1 then 0. else List.fold_left add_squared_dev 0. floats /. n
    in
    (mean, sqrt variance)

(** [mean_int ints] is the arithmetic mean of [ints] as a float, or [0.] for
    an empty list. *)
let mean_int = function
  | [] -> 0.
  | ints ->
      let total = List.fold_left ( + ) 0 ints in
      float_of_int total /. float_of_int (List.length ints)

(** [record_evaluation ~vec_env ~policy ~n_episodes ?max_steps ~layout
    ?overlay ~sink ()] steps every environment of [vec_env] in lock-step with
    [policy], pushes one composed image per vector step to [sink], and returns
    evaluation statistics.

    - [policy] maps the batched observations to
      [(actions, log_probs, values)]. [actions], and the optional arrays when
      present, must have one entry per environment.
    - Each environment's rendered frame is passed through [overlay] (default
      [Overlay.identity]). The frames are then composed with
      [Utils.compose_grid]: [`Single_each] lays them out on one row, and
      [`NxM_grid (rows, cols)] requires [rows * cols] to equal the number of
      environments.
    - Recording stops once [n_episodes] episodes have finished or, if
      [max_steps] is given, after that many vector steps.

    The statistics cover the first [n_episodes] episodes to finish, in
    completion order with ties in one step broken by environment index. If
    fewer episodes finished, all of them are used, and the result's
    [n_episodes] field reports the actual count. [std_reward] is the
    population standard deviation.

    [sink] is closed on every exit path, except when [n_episodes <= 0] or the
    vector environment is empty; both are rejected before recording starts.

    @raise Invalid_argument on the preconditions above, on a grid layout that
    does not match the number of environments, when the policy or renderer
    returns arrays of the wrong length, or when an environment renders no
    frame or a non-image frame. *)
let record_evaluation ~vec_env ~policy ~n_episodes ?max_steps ~layout ?overlay
    ~sink () =
  if n_episodes <= 0 then invalid_arg "record_evaluation: n_episodes > 0";
  let overlay = default_overlay overlay in
  let num_envs = Vector_env.num_envs vec_env in
  let envs = Vector_env.envs vec_env in
  let action_space =
    if Array.length envs = 0 then
      invalid_arg "record_evaluation: empty vector environment"
    else Env.action_space envs.(0)
  in
  let finalize () = Sink.close sink in
  Fun.protect ~finally:finalize (fun () ->
      (* Validate the layout before resetting or stepping anything, so a bad
         grid fails fast instead of after one wasted vector step. *)
      let compose =
        match layout with
        | `Single_each ->
            fun frames -> Utils.compose_grid ~rows:1 ~cols:num_envs frames
        | `NxM_grid (rows, cols) ->
            if rows * cols <> num_envs then
              invalid_arg
                "record_evaluation: grid layout must cover all environments";
            fun frames -> Utils.compose_grid ~rows ~cols frames
      in
      (* Resetting inside [Fun.protect] guarantees [sink] is closed even if
         the reset itself raises. *)
      let observations, _infos = Vector_env.reset vec_env () in
      let observations = ref observations in
      (* Per-environment accumulators for the episode currently in progress. *)
      let returns = Array.make num_envs 0. in
      let lengths = Array.make num_envs 0 in
      let episode_indices = Array.make num_envs 0 in
      (* Finished episodes, accumulated newest-first. *)
      let completed_returns = ref [] in
      let completed_lengths = ref [] in
      let total_episodes = ref 0 in
      let step_idx = ref 0 in
      let rec loop () =
        if !total_episodes >= n_episodes then ()
        else
          match max_steps with
          | Some limit when !step_idx >= limit -> ()
          | _ ->
              let actions, log_probs_opt, values_opt = policy !observations in
              if Array.length actions <> num_envs then
                invalid_arg
                  "record_evaluation: policy returned mismatched action array";
              (match log_probs_opt with
              | Some arr when Array.length arr <> num_envs ->
                  invalid_arg
                    "record_evaluation: policy returned mismatched log_probs"
              | _ -> ());
              (match values_opt with
              | Some arr when Array.length arr <> num_envs ->
                  invalid_arg
                    "record_evaluation: policy returned mismatched values"
              | _ -> ());
              let step = Vector_env.step vec_env actions in
              let frames_opt = Vector_env.render vec_env in
              if Array.length frames_opt <> num_envs then
                invalid_arg
                  "record_evaluation: render frame count mismatch with envs";
              let frames =
                Array.map
                  (function
                    | None ->
                        invalid_arg
                          "record_evaluation: environment did not return frames"
                    | Some frame -> expect_image_frame "record_evaluation" frame)
                  frames_opt
              in
              for idx = 0 to num_envs - 1 do
                returns.(idx) <- returns.(idx) +. step.rewards.(idx);
                lengths.(idx) <- lengths.(idx) + 1;
                let done_flag =
                  step.terminations.(idx) || step.truncations.(idx)
                in
                let action_value = Space.pack action_space actions.(idx) in
                let value = Option.map (fun arr -> arr.(idx)) values_opt in
                let log_prob =
                  Option.map (fun arr -> arr.(idx)) log_probs_opt
                in
                let ctx =
                  {
                    Overlay.step_idx = !step_idx;
                    episode_idx = episode_indices.(idx);
                    info = step.infos.(idx);
                    action = Some action_value;
                    value;
                    log_prob;
                    reward = step.rewards.(idx);
                    done_ = done_flag;
                  }
                in
                frames.(idx) <- overlay frames.(idx) ctx;
                if done_flag then (
                  (* Archive the finished episode and start a fresh tally for
                     this environment. *)
                  completed_returns := returns.(idx) :: !completed_returns;
                  completed_lengths := lengths.(idx) :: !completed_lengths;
                  returns.(idx) <- 0.;
                  lengths.(idx) <- 0;
                  episode_indices.(idx) <- episode_indices.(idx) + 1;
                  incr total_episodes)
              done;
              incr step_idx;
              Sink.push sink (Render.Image (compose frames));
              observations := step.observations;
              loop ()
      in
      loop ();
      (* The completed lists are newest-first. Restore chronological order
         before truncating, so the statistics describe the first [n_episodes]
         episodes to finish rather than the last ones (several environments
         may finish on the final step and overshoot [n_episodes]). *)
      let first_n completed = take_first n_episodes (List.rev completed) in
      let rewards = first_n !completed_returns in
      let episode_lengths = first_n !completed_lengths in
      let mean_reward, std_reward = mean_and_std rewards in
      let mean_length = mean_int episode_lengths in
      let open Training in
      { mean_reward; std_reward; mean_length; n_episodes = List.length rewards })