package fehu

  1. Overview
  2. Docs
Reinforcement learning for OCaml

Install

dune-project
 Dependency

Authors

Maintainers

Sources

raven-1.0.0.alpha3.tbz
sha256=96d35ce03dfbebd2313657273e24c2e2d20f9e6c7825b8518b69bd1d6ed5870f
sha512=90c5053731d4108f37c19430e45456063e872b04b8a1bbad064c356e1b18e69222de8bfcf4ec14757e71f18164ec6e4630ba770dbcb1291665de5418827d1465

doc/src/fehu.envs/cartpole.ml.html

Source file cartpole.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
(*---------------------------------------------------------------------------
  Copyright (c) 2026 The Raven authors. All rights reserved.
  SPDX-License-Identifier: ISC
  ---------------------------------------------------------------------------*)

open Fehu

(* Observation tensor: an Nx tensor with float32 elements. In this file it is
   always produced by [make_obs], which builds a 1-D tensor of four entries. *)
type obs = (float, Nx.float32_elt) Nx.t

(* Action tensor: an Nx tensor with int32 elements. [step] decodes it with
   [Space.Discrete.to_int]. *)
type act = (int32, Nx.int32_elt) Nx.t

(* Render output: a plain string ("ansi" is the only render mode this
   environment registers). *)
type render = string

(* Physics constants matching Gymnasium CartPole-v1.
   NOTE(review): units are not stated in this file; presumably SI (m/s^2, kg,
   m, N, s) as in Gymnasium -- confirm. *)

(* Gravitational acceleration; multiplies sin(theta) in the angular
   acceleration computed by [step]. *)
let gravity = 9.8

(* Mass of the cart; in this file it only enters through [total_mass]. *)
let masscart = 1.0

(* Mass of the pole; used in [total_mass], [polemass_length] and the
   denominator of the angular acceleration. *)
let masspole = 0.1

(* Combined cart + pole mass; divides the force terms in [step]. *)
let total_mass = masscart +. masspole

(* Half of the pole length, as the name says; scales the denominator of the
   angular acceleration. *)
let half_pole_length = 0.5

(* Pole mass times half pole length; precomputed because [step] uses the
   product twice. *)
let polemass_length = masspole *. half_pole_length

(* Magnitude of the force applied to the cart on every step; the action only
   selects its sign. *)
let force_mag = 10.0

(* Integration step: [step] advances each state variable by [tau] times its
   rate of change. *)
let tau = 0.02

(* Termination thresholds *)

(* Pole-angle limit: 12 degrees converted to radians (theta is fed to
   [cos]/[sin]). [step] terminates the episode once theta is strictly outside
   [-theta_threshold, theta_threshold]. *)
let theta_threshold = 12. *. Float.pi /. 180.

(* Cart-position limit: [step] terminates the episode once x is strictly
   outside [-x_threshold, x_threshold]. *)
let x_threshold = 2.4

(* Step count at which an episode that has not terminated is reported as
   truncated. *)
let max_steps = 500

(* Float32-representable large bound for "unbounded" dimensions. Used as the
   +/- limit of the second and fourth entries of [observation_space], i.e. the
   x_dot and theta_dot slots in [make_obs] order.
   NOTE(review): this decimal literal approximates the largest finite float32;
   as an OCaml (double) float it may sit marginally above that value -- confirm
   this is fine for how Space.Box compares bounds. *)
let f32_max = 3.4028235e38

(* Box space describing observations, in [make_obs] order
   (x, x_dot, theta, theta_dot). Position and angle are bounded at twice
   their termination thresholds; both velocity entries use [f32_max] as a
   stand-in for "unbounded". The space is symmetric, so the lower bounds are
   the negated upper bounds. *)
let observation_space =
  let x_bound = 2. *. x_threshold in
  let theta_bound = 2. *. theta_threshold in
  let high = [| x_bound; f32_max; theta_bound; f32_max |] in
  let low = Array.map Float.neg high in
  Space.Box.create ~low ~high

(* Discrete action space with 2 actions. [step] decodes an action to an int:
   1 applies +[force_mag] to the cart, any other value applies -[force_mag]. *)
let action_space = Space.Discrete.create 2

(* [make_obs x x_dot theta theta_dot] packs the four state variables, in that
   order, into a 1-D float32 Nx tensor with one entry per variable. *)
let make_obs x x_dot theta theta_dot =
  let state = [| x; x_dot; theta; theta_dot |] in
  Nx.create Nx.float32 [| Array.length state |] state

(* [make ?render_mode ()] creates a CartPole environment through [Env.create],
   with id "CartPole-v1" and "ansi" as its single render mode.

   The cart position and velocity, the pole angle and angular velocity, and a
   step counter live in refs created by this call. They start at zero and are
   shared by the [reset], [step] and [render] closures. [render_mode] is
   forwarded to [Env.create] unchanged. *)
let make ?render_mode () =
  let x = ref 0.0
  and x_dot = ref 0.0
  and theta = ref 0.0
  and theta_dot = ref 0.0
  and steps = ref 0 in
  (* Snapshot of the current state as an observation tensor. *)
  let observe () = make_obs !x !x_dot !theta !theta_dot in
  (* Re-draws every state variable, clears the step counter and returns the
     initial observation together with an empty info record. *)
  let reset _env ?options:_ () =
    (* Maps one [Nx.rand] sample v to (v - 0.5) * 0.1; presumably v lies in
       [0, 1), giving a small perturbation around zero -- confirm against the
       Nx documentation. *)
    let draw () =
      let sample = Nx.to_array (Nx.rand Nx.float32 [| 1 |]) in
      (sample.(0) -. 0.5) *. 0.1
    in
    (* One draw per variable, in the order x, x_dot, theta, theta_dot. *)
    List.iter (fun cell -> cell := draw ()) [ x; x_dot; theta; theta_dot ];
    steps := 0;
    (observe (), Info.empty)
  in
  (* Advances the simulation by one [tau]-sized Euler step under [action]. *)
  let step _env action =
    (* Action 1 pushes with +force_mag; anything else with -force_mag. *)
    let force =
      match Space.Discrete.to_int action with
      | 1 -> force_mag
      | _ -> -.force_mag
    in
    (* Freeze the pre-step state so every update below reads old values. *)
    let pos = !x and vel = !x_dot and ang = !theta and ang_vel = !theta_dot in
    let cos_ang = cos ang and sin_ang = sin ang in
    let temp =
      (force +. (polemass_length *. ang_vel *. ang_vel *. sin_ang))
      /. total_mass
    in
    let ang_acc =
      ((gravity *. sin_ang) -. (cos_ang *. temp))
      /. (half_pole_length
         *. ((4.0 /. 3.0) -. (masspole *. cos_ang *. cos_ang /. total_mass)))
    in
    let lin_acc =
      temp -. (polemass_length *. ang_acc *. cos_ang /. total_mass)
    in
    x := pos +. (tau *. vel);
    x_dot := vel +. (tau *. lin_acc);
    theta := ang +. (tau *. ang_vel);
    theta_dot := ang_vel +. (tau *. ang_acc);
    incr steps;
    (* The episode terminates once the cart or the pole is strictly beyond
       its threshold on either side. *)
    let beyond limit value = Float.abs value > limit in
    let terminated = beyond x_threshold !x || beyond theta_threshold !theta in
    (* Truncation is only reported when the step limit is hit without a
       termination on the same step. *)
    let truncated = !steps >= max_steps && not terminated in
    (* NOTE(review): the terminating step is rewarded 0.0 here, whereas the
       Gymnasium CartPole-v1 documentation describes +1 for that step as
       well -- confirm the difference is intentional. *)
    let reward = if terminated then 0.0 else 1.0 in
    let info = Info.set "steps" (Info.int !steps) Info.empty in
    let observation = observe () in
    Env.step_result ~observation ~reward ~terminated ~truncated ~info ()
  in
  (* One-line textual summary of the state; the pole angle is printed in
     degrees, the other quantities as stored. *)
  let render () =
    let theta_deg = !theta *. 180. /. Float.pi in
    let line =
      Printf.sprintf
        "CartPole: x=%.3f, x_dot=%.3f, theta=%.3f\xc2\xb0, theta_dot=%.3f, \
         steps=%d"
        !x !x_dot theta_deg !theta_dot !steps
    in
    Some line
  in
  Env.create ?render_mode ~render_modes:[ "ansi" ] ~id:"CartPole-v1"
    ~observation_space ~action_space ~reset ~step ~render ()