package fehu
Reinforcement learning framework for OCaml
Install
dune-project
Dependency
Authors
Maintainers
Sources
raven-1.0.0.alpha1.tbz
sha256=8e277ed56615d388bc69c4333e43d1acd112b5f2d5d352e2453aef223ff59867
sha512=369eda6df6b84b08f92c8957954d107058fb8d3d8374082e074b56f3a139351b3ae6e3a99f2d4a4a2930dd950fd609593467e502368a13ad6217b571382da28c
doc/src/fehu.envs/grid_world.ml.html
Source file grid_world.ml
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86
(* Simple grid-world environment for Fehu.

   The agent lives on a [grid_size] x [grid_size] grid of (row, column)
   cells, with (0, 0) rendered as the top-left cell. Each step applies one of
   four discrete actions. A move that would leave the grid or enter the
   obstacle cell leaves the agent where it is. The episode terminates when the
   agent reaches the goal cell and is truncated after [max_steps] steps. *)

open Fehu

type observation = (int32, Rune.int32_elt) Rune.t
type action = (int32, Rune.int32_elt) Rune.t
type render = string

(* Mutable per-environment state: the agent's (row, column) and the number of
   steps taken since the last [reset]. *)
type state = { mutable position : int * int; mutable steps : int }

let grid_size = 5

(* Fixed cells of the layout, as (row, column). Shared by the dynamics
   ([is_goal], [is_obstacle], [reset]) and by [render] so they cannot drift
   apart. *)
let start_pos = (0, 0)
let goal_pos = (grid_size - 1, grid_size - 1)
let obstacle_pos = (2, 2)

(* Episode parameters. *)
let max_steps = 200
let goal_reward = 10.0
let step_reward = -1.0

let observation_space = Space.Multi_discrete.create [| grid_size; grid_size |]
let action_space = Space.Discrete.create 4

let metadata =
  Metadata.default
  |> Metadata.add_render_mode "ansi"
  |> Metadata.with_description
       (Some "Simple 5x5 grid world with goal and obstacles")
  |> Metadata.add_author "Fehu New"
  |> Metadata.with_version (Some "0.1.0")

let is_goal pos = pos = goal_pos
let is_obstacle pos = pos = obstacle_pos

(* A cell is valid if it lies inside the grid and is not the obstacle. *)
let is_valid_pos (r, c) =
  r >= 0 && r < grid_size && c >= 0 && c < grid_size
  && not (is_obstacle (r, c))

(* Observation for a cell: a 2-element int32 tensor [| row; col |]. *)
let observation_of_pos (r, c) =
  Rune.create Rune.int32 [| 2 |] [| Int32.of_int r; Int32.of_int c |]

(* Put the agent back on the start cell and clear the step counter.
   [options] is accepted for interface compatibility and ignored. *)
let reset _env ?options:_ () state =
  state.position <- start_pos;
  state.steps <- 0;
  (observation_of_pos start_pos, Info.empty)

(* Advance the environment by one action.
   The action tensor is reshaped to one element and read as an int.
   0 = row - 1, 1 = row + 1, 2 = column - 1, 3 = column + 1; any other value
   leaves the agent in place. Invalid moves (off-grid or into the obstacle)
   also leave the agent in place. Reward is [goal_reward] on reaching the
   goal, [step_reward] otherwise. The current step count is reported in the
   info under "steps". *)
let step _env action state =
  let action_value =
    let arr : Int32.t array = Rune.to_array (Rune.reshape [| 1 |] action) in
    Int32.to_int arr.(0)
  in
  let row, col = state.position in
  let candidate =
    match action_value with
    | 0 -> (row - 1, col)
    | 1 -> (row + 1, col)
    | 2 -> (row, col - 1)
    | 3 -> (row, col + 1)
    | _ -> (row, col)
  in
  let next_pos = if is_valid_pos candidate then candidate else state.position in
  state.position <- next_pos;
  state.steps <- state.steps + 1;
  let terminated = is_goal next_pos in
  let truncated = state.steps >= max_steps in
  let reward = if terminated then goal_reward else step_reward in
  let info = Info.set "steps" (Info.int state.steps) Info.empty in
  Env.transition
    ~observation:(observation_of_pos next_pos)
    ~reward ~terminated ~truncated ~info ()

(* ANSI rendering: a "Position: (r, c)" header followed by one text line per
   grid row, using '.' for empty cells, 'G' for the goal, '#' for the
   obstacle and 'A' for the agent. *)
let render state =
  let buffer = Bytes.make (grid_size * grid_size) '.' in
  let mark (r, c) ch = Bytes.set buffer ((r * grid_size) + c) ch in
  mark goal_pos 'G';
  mark obstacle_pos '#';
  (* Draw the agent last so it remains visible while standing on the goal.
     It can never be on the obstacle, since [is_valid_pos] rejects that
     cell. *)
  mark state.position 'A';
  let row, col = state.position in
  let rows =
    List.init grid_size (fun r ->
        Bytes.sub_string buffer (r * grid_size) grid_size)
  in
  Format.asprintf "Position: (%d, %d)@.%a" row col
    (Format.pp_print_list
       ~pp_sep:(fun fmt () -> Format.fprintf fmt "@.")
       Format.pp_print_string)
    rows

(* Create a GridWorld-v0 environment. Each call gets its own mutable [state],
   which the reset/step/render closures share. *)
let make ~rng () =
  let state = { position = start_pos; steps = 0 } in
  Env.create ~id:"GridWorld-v0" ~metadata ~rng ~observation_space
    ~action_space
    ~reset:(fun env ?options () -> reset env ?options () state)
    ~step:(fun env action -> step env action state)
    ~render:(fun _ -> Some (render state))
    ~close:(fun _ -> ())
    ()
sectionYPositions = computeSectionYPositions($el), 10)"
x-init="setTimeout(() => sectionYPositions = computeSectionYPositions($el), 10)"
>