package fehu

  1. Overview
  2. Docs
Reinforcement learning framework for OCaml

Install

dune-project
 Dependency

Authors

Maintainers

Sources

raven-1.0.0.alpha1.tbz
sha256=8e277ed56615d388bc69c4333e43d1acd112b5f2d5d352e2453aef223ff59867
sha512=369eda6df6b84b08f92c8957954d107058fb8d3d8374082e074b56f3a139351b3ae6e3a99f2d4a4a2930dd950fd609593467e502368a13ad6217b571382da28c

doc/src/fehu.envs/cartpole.ml.html

Source file cartpole.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
open Fehu

(* Observations are float32 tensors; [reset] and [step] build them with shape
   [| 4 |] holding (x, x_dot, theta, theta_dot). *)
type observation = (float, Rune.float32_elt) Rune.t

(* Actions are int32 tensors; [step] only reads the first element. *)
type action = (int32, Rune.int32_elt) Rune.t

(* Rendering produces a one-line text summary of the state. *)
type render = string

(* Mutable simulation state shared by the reset/step/render callbacks that
   [make] installs. *)
type state = {
  mutable x : float; (* cart position *)
  mutable x_dot : float; (* cart velocity *)
  mutable theta : float; (* pole angle, in radians *)
  mutable theta_dot : float; (* pole angular velocity *)
  mutable steps : int; (* steps taken since the last reset; drives truncation *)
  rng : Rune.Rng.key ref; (* key advanced on every reset to draw a new start *)
}

(* Physical parameters; the values follow Gymnasium's CartPole-v1. *)
let gravity = 9.8
let masscart = 1.0
let masspole = 0.1
let total_mass = masspole +. masscart
let length = 0.5 (* half of the pole's length *)
let polemass_length = length *. masspole
let force_mag = 10.0 (* magnitude of the push applied by either action *)
let tau = 0.02 (* integration time step *)

(* Episode failure limits: the pole may tilt at most 12 degrees (stored in
   radians) and the cart may travel at most [x_threshold] from the origin. *)
let theta_threshold_radians = Float.pi *. 12. /. 180.
let x_threshold = 2.4

(* Observation bounds for (x, x_dot, theta, theta_dot). Cart position and pole
   angle are bounded at twice their termination thresholds, and the two
   velocities use [Float.max_float] as an effectively unbounded limit. The
   box is symmetric about zero. *)
let observation_space =
  let x_limit = x_threshold *. 2. in
  let theta_limit = theta_threshold_radians *. 2. in
  let unbounded = Float.max_float in
  let high = [| x_limit; unbounded; theta_limit; unbounded |] in
  let low = Array.map Float.neg high in
  Space.Box.create ~low ~high

(* Two discrete actions: [step] applies +force_mag for action 1 and
   -force_mag for any other value. *)
let action_space =
  let n_actions = 2 in
  Space.Discrete.create n_actions

(* Environment metadata: advertises the "ansi" render mode and records a
   description, author and version string. *)
let metadata =
  let meta = Metadata.default in
  let meta = Metadata.add_render_mode "ansi" meta in
  let meta =
    Metadata.with_description (Some "Classic cart-pole balancing problem") meta
  in
  let meta = Metadata.add_author "Fehu" meta in
  Metadata.with_version (Some "0.1.0") meta

(* [reset] re-initialises [state] in place and returns the first observation
   with empty info. Each of x, x_dot, theta and theta_dot is set to
   [(u - 0.5) * 0.1] for a fresh uniform sample [u] (presumably from the unit
   interval, giving a value near zero). The step counter is cleared.
   [options] is accepted but ignored. *)
let reset _env ?options:_ () state =
  (* Five keys: the first replaces the stored key so the next reset draws
     different values; the other four each seed one state component. *)
  let keys = Rune.Rng.split !(state.rng) ~n:5 in
  state.rng := keys.(0);
  let draw idx =
    let sample = Rune.Rng.uniform keys.(idx + 1) Rune.float32 [| 1 |] in
    let u = (Rune.to_array sample).(0) in
    (u -. 0.5) *. 0.1
  in
  (* Array.init applies [draw] to indices 0..3 in order. *)
  let init = Array.init 4 draw in
  state.x <- init.(0);
  state.x_dot <- init.(1);
  state.theta <- init.(2);
  state.theta_dot <- init.(3);
  state.steps <- 0;
  (Rune.create Rune.float32 [| 4 |] init, Info.empty)

(* [step] advances the simulation by one time step of [tau], using explicit
   Euler integration of the Gymnasium CartPole-v1 dynamics, and mutates
   [state] in place.

   [action] is read through its first int32 element: 1 pushes with
   +force_mag, any other value with -force_mag.

   The transition is [terminated] when the cart position or the pole angle
   leaves its threshold. It is [truncated] once 500 steps have elapsed since
   the last reset.

   As in Gymnasium, the reward is 1.0 for every step, including the step on
   which the episode terminates. A step taken after the previous step had
   already terminated (stepping past the end of an episode) yields 0.0. The
   info carries the current step count under "steps". *)
let step _env action state =
  (* Failure test shared by the pre-step and post-step states. *)
  let out_of_bounds ~x ~theta =
    x < -.x_threshold || x > x_threshold
    || theta < -.theta_threshold_radians
    || theta > theta_threshold_radians
  in
  (* [terminated] is computed from the post-update state, so the previous
     call reported termination iff the current state is already out of
     bounds. This recovers Gymnasium's "steps beyond terminated" case without
     extra bookkeeping in [state]. *)
  let already_terminated = out_of_bounds ~x:state.x ~theta:state.theta in

  let action_value =
    let arr : Int32.t array = Rune.to_array action in
    Int32.to_int arr.(0)
  in

  let force = if action_value = 1 then force_mag else -.force_mag in

  let costheta = cos state.theta in
  let sintheta = sin state.theta in

  (* Equations from Gymnasium CartPole-v1 *)
  let temp =
    (force
    +. (polemass_length *. state.theta_dot *. state.theta_dot *. sintheta))
    /. total_mass
  in
  let thetaacc =
    ((gravity *. sintheta) -. (costheta *. temp))
    /. (length
       *. ((4.0 /. 3.0) -. (masspole *. costheta *. costheta /. total_mass)))
  in
  let xacc = temp -. (polemass_length *. thetaacc *. costheta /. total_mass) in

  (* Euler integration: positions advance with the pre-update velocities. *)
  state.x <- state.x +. (tau *. state.x_dot);
  state.x_dot <- state.x_dot +. (tau *. xacc);
  state.theta <- state.theta +. (tau *. state.theta_dot);
  state.theta_dot <- state.theta_dot +. (tau *. thetaacc);
  state.steps <- state.steps + 1;

  let terminated = out_of_bounds ~x:state.x ~theta:state.theta in
  let truncated = state.steps >= 500 in
  (* Gymnasium pays +1 on the step where the pole falls; only steps taken
     after the episode was already over earn nothing. *)
  let reward = if terminated && already_terminated then 0.0 else 1.0 in

  let obs =
    Rune.create Rune.float32 [| 4 |]
      [| state.x; state.x_dot; state.theta; state.theta_dot |]
  in

  let info = Info.set "steps" (Info.int state.steps) Info.empty in
  Env.transition ~observation:obs ~reward ~terminated ~truncated ~info ()

(* One-line text summary of [state]. The pole angle is converted from
   radians to degrees; every other quantity is printed as stored. *)
let render state =
  let { x; x_dot; theta; theta_dot; steps; _ } = state in
  let theta_deg = theta *. 180. /. Float.pi in
  Printf.sprintf
    "CartPole: x=%.3f, x_dot=%.3f, theta=%.3f°, theta_dot=%.3f, steps=%d"
    x x_dot theta_deg theta_dot steps

(* [make ~rng ()] builds the "CartPole-v1" environment. A single mutable
   [state] is allocated here and captured by the reset, step and render
   callbacks. It holds all-zero physics until the first reset. [rng] seeds
   the state's key and is also passed to [Env.create]. *)
let make ~rng () =
  let state =
    {
      x = 0.0;
      x_dot = 0.0;
      theta = 0.0;
      theta_dot = 0.0;
      steps = 0;
      rng = ref rng;
    }
  in
  let on_reset env ?options () = reset env ?options () state in
  let on_step env action = step env action state in
  let on_render _ = Some (render state) in
  let on_close _ = () in
  Env.create ~id:"CartPole-v1" ~metadata ~rng ~observation_space ~action_space
    ~reset:on_reset ~step:on_step ~render:on_render ~close:on_close ()