package fehu

  1. Overview
  2. Docs

Source file wrapper.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
(** [derive_id env suffix] is [Some (id ^ suffix)] when [Env.id env] is
    [Some id], and [None] when [env] has no id. *)
let derive_id env suffix =
  Option.map (fun base -> base ^ suffix) (Env.id env)

(** [inherit_metadata env] is the metadata of [env], unchanged. *)
let inherit_metadata env = env |> Env.metadata

(** [map_observation ~observation_space ~f env] builds an environment around
    [env] in which every observation/info pair coming out of [env] is passed
    through [f] before being returned.

    - On reset, [f] receives the observation and info returned by
      [Env.reset env] and its result is returned as is.
    - On step, [f] receives the observation and info of the inner transition;
      reward, [terminated] and [truncated] are copied from the inner
      transition.

    The wrapper advertises [observation_space] and reuses the action space,
    metadata and RNG of [env]. Rendering and closing are forwarded to [env].
    When [env] has an id, the wrapper's id is that id followed by
    ["/ObservationWrapper"]. *)
let map_observation ~(observation_space : 'obs Space.t)
    ~(f : 'inner_obs -> Info.t -> 'obs * Info.t)
    (env : ('inner_obs, 'act, 'render) Env.t) : ('obs, 'act, 'render) Env.t =
  let reset _wrapper ?options () =
    let inner_observation, inner_info = Env.reset env ?options () in
    f inner_observation inner_info
  in
  let step _wrapper action =
    let inner = Env.step env action in
    let observation, info = f inner.observation inner.info in
    (* The observation type changes, so a fresh transition is built. *)
    Env.transition ~observation ~reward:inner.reward
      ~terminated:inner.terminated ~truncated:inner.truncated ~info ()
  in
  Env.create
    ?id:(derive_id env "/ObservationWrapper")
    ~metadata:(inherit_metadata env) ~rng:(Env.rng env)
    ~render:(fun _ -> Env.render env)
    ~close:(fun _ -> Env.close env)
    ~observation_space ~action_space:(Env.action_space env) ~reset ~step ()

(** [map_action ~action_space ~f env] builds an environment around [env] whose
    actions are converted with [f] before being handed to [Env.step env].

    The wrapper advertises [action_space] and reuses the observation space,
    metadata and RNG of [env]. Reset, rendering and closing are forwarded to
    [env]; the fields of each inner transition are copied into the returned
    transition. When [env] has an id, the wrapper's id is that id followed by
    ["/ActionWrapper"]. *)
let map_action ~(action_space : 'act Space.t) ~(f : 'act -> 'inner_act)
    (env : ('obs, 'inner_act, 'render) Env.t) : ('obs, 'act, 'render) Env.t =
  let reset _wrapper ?options () = Env.reset env ?options () in
  let step _wrapper outer_action =
    let inner = Env.step env (f outer_action) in
    Env.transition ~observation:inner.observation ~reward:inner.reward
      ~terminated:inner.terminated ~truncated:inner.truncated
      ~info:inner.info ()
  in
  Env.create
    ?id:(derive_id env "/ActionWrapper")
    ~metadata:(inherit_metadata env) ~rng:(Env.rng env)
    ~render:(fun _ -> Env.render env)
    ~close:(fun _ -> Env.close env)
    ~observation_space:(Env.observation_space env)
    ~action_space ~reset ~step ()

(** [map_reward ~f env] builds an environment around [env] in which the reward
    and info of every step are replaced by the pair returned by
    [f ~reward ~info], where [reward] and [info] come from the inner
    transition. All other transition fields are kept.

    Spaces, metadata and RNG are those of [env]; reset, rendering and closing
    are forwarded to [env]. When [env] has an id, the wrapper's id is that id
    followed by ["/RewardWrapper"]. *)
let map_reward ~(f : reward:float -> info:Info.t -> float * Info.t)
    (env : ('obs, 'act, 'render) Env.t) : ('obs, 'act, 'render) Env.t =
  let reset _wrapper ?options () = Env.reset env ?options () in
  let step _wrapper action =
    let inner = Env.step env action in
    let reward, info = f ~reward:inner.reward ~info:inner.info in
    { inner with reward; info }
  in
  Env.create
    ?id:(derive_id env "/RewardWrapper")
    ~metadata:(inherit_metadata env) ~rng:(Env.rng env)
    ~render:(fun _ -> Env.render env)
    ~close:(fun _ -> Env.close env)
    ~observation_space:(Env.observation_space env)
    ~action_space:(Env.action_space env) ~reset ~step ()

(** [map_info ~f env] builds an environment around [env] in which the info
    value returned by both reset and step is replaced by [f info]. The
    observation and every other transition field are passed through
    unchanged.

    Spaces, metadata and RNG are those of [env]; rendering and closing are
    forwarded to [env]. When [env] has an id, the wrapper's id is that id
    followed by ["/InfoWrapper"]. *)
let map_info ~(f : Info.t -> Info.t) (env : ('obs, 'act, 'render) Env.t) :
    ('obs, 'act, 'render) Env.t =
  let reset _wrapper ?options () =
    let observation, raw_info = Env.reset env ?options () in
    (observation, f raw_info)
  in
  let step _wrapper action =
    let inner = Env.step env action in
    { inner with info = f inner.info }
  in
  Env.create
    ?id:(derive_id env "/InfoWrapper")
    ~metadata:(inherit_metadata env) ~rng:(Env.rng env)
    ~render:(fun _ -> Env.render env)
    ~close:(fun _ -> Env.close env)
    ~observation_space:(Env.observation_space env)
    ~action_space:(Env.action_space env) ~reset ~step ()

(** [clamp_tensor ~low ~high tensor] returns a new [Rune.float32] tensor with
    the shape of [tensor]. The element at position [i] of [Rune.to_array
    tensor] is replaced by [low.(i)] when it is smaller than [low.(i)], by
    [high.(i)] when it is larger than [high.(i)], and kept otherwise.

    [low] and [high] are indexed with the same positions as the flattened
    tensor, so both must hold at least as many elements as [tensor]. A value
    for which neither comparison holds (for instance NaN) is kept as is. *)
let clamp_tensor ~low ~high tensor =
  let clamp_at idx value =
    let lo = low.(idx) in
    let hi = high.(idx) in
    if value < lo then lo else if value > hi then hi else value
  in
  (* [Array.mapi] allocates the result, so the source array is not mutated. *)
  let clamped = Array.mapi clamp_at (Rune.to_array tensor) in
  Rune.create Rune.float32 (Rune.shape tensor) clamped

(** [clip_action env] wraps [env] so that every action is clamped to the
    bounds of [env]'s action space (see [clamp_tensor]) before being passed to
    [env].

    The wrapper advertises a relaxed box as its action space: a dimension
    whose lower and upper bounds are equal keeps those bounds, and every other
    dimension ranges from [Float.neg_infinity] to [Float.infinity].

    @raise Invalid_argument
      if the low and high bounds of the action space have different lengths. *)
let clip_action (env : ('obs, Space.Box.element, 'render) Env.t) :
    ('obs, Space.Box.element, 'render) Env.t =
  let low, high = Space.Box.bounds (Env.action_space env) in
  let dims = Array.length low in
  if dims <> Array.length high then
    invalid_arg "Wrapper.clip_action: mismatched low/high bounds";
  (* Keep [pinned.(idx)] where both bounds coincide, otherwise [unbounded]. *)
  let relax ~pinned ~unbounded =
    Array.init dims (fun idx ->
        if Float.equal low.(idx) high.(idx) then pinned.(idx) else unbounded)
  in
  let relaxed_low = relax ~pinned:low ~unbounded:Float.neg_infinity in
  let relaxed_high = relax ~pinned:high ~unbounded:Float.infinity in
  let relaxed_space = Space.Box.create ~low:relaxed_low ~high:relaxed_high in
  map_action ~action_space:relaxed_space ~f:(clamp_tensor ~low ~high) env

(** [clip_observation env] wraps [env] so that every observation returned by
    reset and step is clamped to the bounds of [env]'s observation space (see
    [clamp_tensor]). The info value is passed through unchanged and the
    wrapper keeps the observation space of [env].

    @raise Invalid_argument
      if the low and high bounds of the observation space have different
      lengths. *)
let clip_observation (env : (Space.Box.element, 'act, 'render) Env.t) :
    (Space.Box.element, 'act, 'render) Env.t =
  let observation_space = Env.observation_space env in
  let low, high = Space.Box.bounds observation_space in
  if Array.length low <> Array.length high then
    invalid_arg "Wrapper.clip_observation: mismatched low/high bounds";
  let clip observation info = (clamp_tensor ~low ~high observation, info) in
  map_observation ~observation_space ~f:clip env

(** [time_limit ~max_episode_steps env] wraps [env] with a step counter that
    truncates episodes after [max_episode_steps] steps.

    The counter is incremented on every step, before [env] is stepped, and is
    set back to [0] on reset, whenever the inner transition is terminated or
    truncated, and after the wrapper itself truncates an episode.

    When the inner transition is neither terminated nor truncated and the
    counter has reached [max_episode_steps], the returned transition has
    [truncated = true] and its info gains the keys ["time_limit.truncated"]
    (set to [true]) and ["time_limit.elapsed_steps"] (the counter value).
    Every other transition is returned unchanged.

    @raise Invalid_argument if [max_episode_steps] is not positive. *)
let time_limit ~max_episode_steps env =
  if max_episode_steps <= 0 then
    invalid_arg "Wrapper.time_limit: max_episode_steps must be positive";
  let elapsed = ref 0 in
  let mark_truncated info step_count =
    let info = Info.set "time_limit.truncated" (Info.bool true) info in
    Info.set "time_limit.elapsed_steps" (Info.int step_count) info
  in
  let reset _wrapper ?options () =
    elapsed := 0;
    Env.reset env ?options ()
  in
  let step _wrapper action =
    incr elapsed;
    let inner = Env.step env action in
    let episode_over = inner.terminated || inner.truncated in
    if episode_over then (
      (* The inner environment ended the episode on its own. *)
      elapsed := 0;
      inner)
    else if !elapsed < max_episode_steps then inner
    else (
      (* Step budget exhausted: record it in the info and truncate. *)
      let info = mark_truncated inner.info !elapsed in
      elapsed := 0;
      { inner with truncated = true; info })
  in
  Env.create
    ?id:(derive_id env "/TimeLimit")
    ~metadata:(inherit_metadata env) ~rng:(Env.rng env)
    ~render:(fun _ -> Env.render env)
    ~close:(fun _ -> Env.close env)
    ~observation_space:(Env.observation_space env)
    ~action_space:(Env.action_space env) ~reset ~step ()

(** [with_metadata ~f env] wraps [env] in an environment whose metadata is
    [f (Env.metadata env)], computed once when the wrapper is created.

    The id, RNG and both spaces are those of [env]; reset, step, rendering and
    closing are forwarded to [env] unchanged. *)
let with_metadata ~f env =
  let metadata = env |> Env.metadata |> f in
  let reset _wrapper ?options () = Env.reset env ?options () in
  let step _wrapper action = Env.step env action in
  Env.create ?id:(Env.id env) ~metadata ~rng:(Env.rng env)
    ~render:(fun _ -> Env.render env)
    ~close:(fun _ -> Env.close env)
    ~observation_space:(Env.observation_space env)
    ~action_space:(Env.action_space env) ~reset ~step ()