package fehu

Reinforcement learning framework for OCaml

Source file grid_world.ml (library fehu.envs)

open Fehu

type observation = (int32, Rune.int32_elt) Rune.t
type action = (int32, Rune.int32_elt) Rune.t
type render = string
type state = { mutable position : int * int; mutable steps : int }

let grid_size = 5

(* Observations are (row, col) pairs with each coordinate in 0..4;
   actions are a single discrete value in 0..3. *)
let observation_space = Space.Multi_discrete.create [| grid_size; grid_size |]
let action_space = Space.Discrete.create 4

let metadata =
  Metadata.default
  |> Metadata.add_render_mode "ansi"
  |> Metadata.with_description
       (Some "Simple 5x5 grid world with goal and obstacles")
  |> Metadata.add_author "Fehu New"
  |> Metadata.with_version (Some "0.1.0")

let is_goal (r, c) = r = 4 && c = 4
let is_obstacle (r, c) = r = 2 && c = 2

let is_valid_pos (r, c) =
  r >= 0 && r < grid_size && c >= 0 && c < grid_size && not (is_obstacle (r, c))

(* Reset: put the agent back at the origin and clear the step counter. *)
let reset _env ?options:_ () state =
  state.position <- (0, 0);
  state.steps <- 0;
  (Rune.create Rune.int32 [| 2 |] [| 0l; 0l |], Info.empty)

let step _env action state =
  (* Decode the one-element int32 action tensor into a plain int. *)
  let action_value =
    let tensor = Rune.reshape [| 1 |] action in
    let arr : Int32.t array = Rune.to_array tensor in
    Int32.to_int arr.(0)
  in
  let row, col = state.position in
  let candidate =
    match action_value with
    | 0 -> (row - 1, col) (* up *)
    | 1 -> (row + 1, col) (* down *)
    | 2 -> (row, col - 1) (* left *)
    | 3 -> (row, col + 1) (* right *)
    | _ -> (row, col) (* out-of-range action: stay put *)
  in
  let next_pos = if is_valid_pos candidate then candidate else state.position in
  state.position <- next_pos;
  state.steps <- state.steps + 1;
  let terminated = is_goal next_pos in
  let truncated = state.steps >= 200 in
  (* +10 for reaching the goal, otherwise a -1 living cost per step. *)
  let reward = if terminated then 10.0 else -1.0 in
  let info = Info.set "steps" (Info.int state.steps) Info.empty in
  let r, c = next_pos in
  let observation =
    Rune.create Rune.int32 [| 2 |] [| Int32.of_int r; Int32.of_int c |]
  in
  Env.transition ~observation ~reward ~terminated ~truncated ~info ()

(* Render the grid as text: 'A' agent, 'G' goal, '#' obstacle, '.' empty. *)
let render state =
  let buffer = Bytes.make (grid_size * grid_size) '.' in
  let row, col = state.position in
  let index = (row * grid_size) + col in
  let () = Bytes.set buffer index 'A' in
  let goal_index = ((grid_size - 1) * grid_size) + (grid_size - 1) in
  Bytes.set buffer goal_index 'G';
  let obstacle_index = (2 * grid_size) + 2 in
  Bytes.set buffer obstacle_index '#';
  let rows =
    List.init grid_size (fun r ->
        let start = r * grid_size in
        Bytes.sub_string buffer start grid_size)
  in
  Format.asprintf "Position: (%d, %d)@.%a" row col
    (Format.pp_print_list
       ~pp_sep:(fun fmt () -> Format.fprintf fmt "@.")
       Format.pp_print_string)
    rows

(* Build the environment; the callbacks close over a fresh mutable state. *)
let make ~rng () =
  let state = { position = (0, 0); steps = 0 } in
  Env.create ~id:"GridWorld-v0" ~metadata ~rng ~observation_space ~action_space
    ~reset:(fun env ?options () -> reset env ?options () state)
    ~step:(fun env action -> step env action state)
    ~render:(fun _ -> Some (render state))
    ~close:(fun _ -> ())
    ()
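
A minimal driver sketch follows. The module path (written here as Grid_world), the RNG constructor Rune.Rng.key, and the Env.reset / Env.step / Env.render calls are assumptions inferred from how the callbacks above are registered with Env.create; they are not shown in this file, so check the Fehu API docs for the exact entry points.

open Fehu
(* Hypothetical driver program: only Env.create, Env.transition and the Rune
   tensor calls above are confirmed by this source file. *)

let () =
  let rng = Rune.Rng.key 0 in
  let env = Grid_world.make ~rng () in
  let _obs, _info = Env.reset env () in
  (* Action 1 = move down, encoded as a one-element int32 tensor as [step]
     expects. *)
  let action = Rune.create Rune.int32 [| 1 |] [| 1l |] in
  let _transition = Env.step env action in
  match Env.render env with
  | Some frame -> print_endline frame
  | None -> ()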