package fehu

  1. Overview
  2. Docs

Source file cartpole.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
open Fehu

(** Observation tensor of float32 values. [reset] and [step] build it with
    shape [[| 4 |]], holding [x; x_dot; theta; theta_dot] in that order. *)
type observation = (float, Rune.float32_elt) Rune.t

(** Action tensor of int32 values; [step] reads only element 0 of it. *)
type action = (int32, Rune.int32_elt) Rune.t

(** Rendered frame; [render] produces it as a line of text. *)
type render = string

(** Mutable simulation state. [make] allocates one value of this type and the
    [reset], [step] and [render] callbacks it registers all operate on it. *)
type state = {
  mutable x : float;  (** cart position *)
  mutable x_dot : float;  (** cart velocity *)
  mutable theta : float;  (** pole angle, in radians *)
  mutable theta_dot : float;  (** pole angular velocity *)
  mutable steps : int;
      (** number of [step] calls since the last [reset] *)
  rng : Rune.Rng.key ref;
      (** current RNG key; [reset] splits it and stores the first subkey
          back *)
}

(* Physical parameters of the simulation. The values are chosen to match
   Gymnasium's CartPole-v1. *)

(** Gravitational acceleration used by the pole dynamics. *)
let gravity = 9.8

(** Mass of the cart. *)
let masscart = 1.

(** Mass of the pole. *)
let masspole = 0.1

(** Combined mass of pole and cart. *)
let total_mass = masspole +. masscart

(** Half of the pole's length. *)
let length = 0.5

(** Pole mass multiplied by the pole half-length. *)
let polemass_length = length *. masspole

(** Magnitude of the force applied to the cart by either action. *)
let force_mag = 10.

(** Integration time step. *)
let tau = 0.02

(* Limits beyond which an episode is considered terminated. *)

(** Largest pole angle magnitude tolerated: 12 degrees, stored in radians. *)
let theta_threshold_radians = Float.pi *. 12. /. 180.

(** Largest cart position magnitude tolerated. *)
let x_threshold = 2.4

(** Box bounds for the observation [x; x_dot; theta; theta_dot].

    The position bound 4.8 is twice [x_threshold] and the angle bound is
    twice [theta_threshold_radians]; both velocity components are bounded
    only by [Float.max_float]. The lower bounds are the exact negation of the
    upper bounds. *)
let observation_space =
  let angle_bound = theta_threshold_radians *. 2. in
  let high = [| 4.8; Float.max_float; angle_bound; Float.max_float |] in
  (* The box is symmetric around zero, so mirror the upper bounds. *)
  let low = Array.map Float.neg high in
  Space.Box.create ~low ~high

(** Discrete action space created with [2] as its argument (presumably the
    number of actions; confirm against [Space.Discrete]). In [step], action
    [1] applies [+force_mag] to the cart and action [0] applies
    [-force_mag]. *)
let action_space = Space.Discrete.create 2

(** Environment metadata: starts from [Metadata.default], then registers the
    "ansi" render mode, a description, the author and a version string. *)
let metadata =
  let with_mode = Metadata.add_render_mode "ansi" Metadata.default in
  let with_text =
    Metadata.with_description
      (Some "Classic cart-pole balancing problem")
      with_mode
  in
  let with_author = Metadata.add_author "Fehu" with_text in
  Metadata.with_version (Some "0.1.0") with_author

(** [reset _env ?options () state] starts a new episode.

    The RNG key held in [state] is split five ways: the first subkey replaces
    the stored key and the remaining four each seed one state variable. The
    step counter is set back to zero. Both [_env] and [?options] are ignored.

    Returns the 4-element float32 observation [x; x_dot; theta; theta_dot]
    paired with [Info.empty]. *)
let reset _env ?options:_ () state =
  let subkeys = Rune.Rng.split !(state.rng) ~n:5 in
  state.rng := subkeys.(0);

  (* [draw k] takes one float32 sample u from subkey [k] and maps it to
     (u - 0.5) * 0.1. That is a small value around zero, within +/-0.05,
     provided Rune.Rng.uniform samples the unit interval -- TODO confirm. *)
  let draw k =
    let sample = Rune.Rng.uniform subkeys.(k) Rune.float32 [| 1 |] in
    let u = (Rune.to_array sample).(0) in
    (u -. 0.5) *. 0.1
  in
  state.x <- draw 1;
  state.x_dot <- draw 2;
  state.theta <- draw 3;
  state.theta_dot <- draw 4;
  state.steps <- 0;

  let { x; x_dot; theta; theta_dot; _ } = state in
  let observation =
    Rune.create Rune.float32 [| 4 |] [| x; x_dot; theta; theta_dot |]
  in
  (observation, Info.empty)

(** [step _env action state] advances the simulation by one time step [tau]
    and returns the resulting transition. [_env] is ignored.

    [action] is read from element 0 of the int32 tensor: [1] applies
    [+force_mag] to the cart, [0] applies [-force_mag].

    The transition carries the new 4-element float32 observation
    [x; x_dot; theta; theta_dot], the reward, [terminated] (cart position or
    pole angle outside its threshold), [truncated] (at least 500 steps since
    the last reset) and an info record holding the step count under the key
    "steps".

    Reward follows Gymnasium CartPole-v1: [1.0] for every step taken from an
    in-bounds state, including the step on which the episode terminates, and
    [0.0] when stepping a state that was already out of bounds.

    @raise Invalid_argument
      if [action] is empty or its first element is neither [0] nor [1];
      [state] is left untouched in that case. *)
let step _env action state =
  let action_value =
    let arr : Int32.t array = Rune.to_array action in
    if Array.length arr = 0 then
      invalid_arg "Cartpole.step: empty action tensor";
    Int32.to_int arr.(0)
  in

  (* Validate before mutating [state]: an out-of-range action used to be
     silently treated as a push in the negative direction. *)
  let force =
    match action_value with
    | 1 -> force_mag
    | 0 -> -.force_mag
    | other ->
        invalid_arg
          (Printf.sprintf "Cartpole.step: invalid action %d, expected 0 or 1"
             other)
  in

  (* Termination test, evaluated on whatever [state] currently holds. *)
  let out_of_bounds () =
    state.x < -.x_threshold || state.x > x_threshold
    || state.theta < -.theta_threshold_radians
    || state.theta > theta_threshold_radians
  in
  (* Remember whether the episode had already ended before this step; only
     such steps earn no reward. *)
  let was_terminal = out_of_bounds () in

  let costheta = cos state.theta in
  let sintheta = sin state.theta in

  (* Equations from Gymnasium CartPole-v1 *)
  let temp =
    (force
    +. (polemass_length *. state.theta_dot *. state.theta_dot *. sintheta))
    /. total_mass
  in
  let thetaacc =
    ((gravity *. sintheta) -. (costheta *. temp))
    /. (length
       *. ((4.0 /. 3.0) -. (masspole *. costheta *. costheta /. total_mass)))
  in
  let xacc = temp -. (polemass_length *. thetaacc *. costheta /. total_mass) in

  (* Explicit Euler integration: positions advance with the velocities from
     before this step, so the order of these updates matters. *)
  state.x <- state.x +. (tau *. state.x_dot);
  state.x_dot <- state.x_dot +. (tau *. xacc);
  state.theta <- state.theta +. (tau *. state.theta_dot);
  state.theta_dot <- state.theta_dot +. (tau *. thetaacc);
  state.steps <- state.steps + 1;

  let terminated = out_of_bounds () in
  let truncated = state.steps >= 500 in
  (* +1 for every step, the terminating one included; 0 only when the caller
     keeps stepping an episode that had already terminated. *)
  let reward = if was_terminal then 0.0 else 1.0 in

  let obs =
    Rune.create Rune.float32 [| 4 |]
      [| state.x; state.x_dot; state.theta; state.theta_dot |]
  in

  let info = Info.set "steps" (Info.int state.steps) Info.empty in
  Env.transition ~observation:obs ~reward ~terminated ~truncated ~info ()

(** [render state] formats the current state as one line of text. The pole
    angle is converted from radians to degrees; the other values are printed
    as stored. Floats are shown with three decimals. *)
let render state =
  let { x; x_dot; theta; theta_dot; steps; rng = _ } = state in
  let theta_degrees = theta *. 180. /. Float.pi in
  Printf.sprintf
    "CartPole: x=%.3f, x_dot=%.3f, theta=%.3f°, theta_dot=%.3f, steps=%d"
    x x_dot theta_degrees theta_dot steps

(** [make ~rng ()] builds the "CartPole-v1" environment.

    A fresh state record is allocated with every physical variable at zero
    and a step count of zero. [rng] seeds that record's key and is also
    passed on to [Env.create]. The registered callbacks all close over this
    one record; closing the environment does nothing. *)
let make ~rng () =
  let state =
    { x = 0.; x_dot = 0.; theta = 0.; theta_dot = 0.; steps = 0; rng = ref rng }
  in
  (* Bind the shared state into each callback expected by [Env.create]. *)
  let reset_fn env ?options () = reset env ?options () state in
  let step_fn env action = step env action state in
  let render_fn _ = Some (render state) in
  let close_fn _ = () in
  Env.create ~id:"CartPole-v1" ~metadata ~rng ~observation_space ~action_space
    ~reset:reset_fn ~step:step_fn ~render:render_fn ~close:close_fn ()