package fehu
Reinforcement learning framework for OCaml
Install
dune-project
Dependency
Authors
Maintainers
Sources
raven-1.0.0.alpha1.tbz
sha256=8e277ed56615d388bc69c4333e43d1acd112b5f2d5d352e2453aef223ff59867
sha512=369eda6df6b84b08f92c8957954d107058fb8d3d8374082e074b56f3a139351b3ae6e3a99f2d4a4a2930dd950fd609593467e502368a13ad6217b571382da28c
doc/src/fehu.algorithms/reinforce.ml.html
Source file reinforce.ml
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455
open Kaun type config = { learning_rate : float; gamma : float; use_baseline : bool; reward_scale : float; entropy_coef : float; max_episode_steps : int; } let default_config = { learning_rate = 0.001; gamma = 0.99; use_baseline = false; reward_scale = 1.0; entropy_coef = 0.01; max_episode_steps = 1000; } type t = { policy_network : module_; mutable policy_params : Rune.float32_elt params; baseline_network : module_ option; mutable baseline_params : Rune.float32_elt params option; policy_optimizer : Rune.float32_elt Optimizer.gradient_transformation; mutable policy_opt_state : Rune.float32_elt Optimizer.opt_state; baseline_optimizer : Rune.float32_elt Optimizer.gradient_transformation option; mutable baseline_opt_state : Rune.float32_elt Optimizer.opt_state option; mutable rng : Rune.Rng.key; n_actions : int; config : config; } type update_metrics = { episode_return : float; episode_length : int; avg_entropy : float; avg_log_prob : float; adv_mean : float; adv_std : float; value_loss : float option; } let create ~policy_network ?baseline_network ~n_actions ~rng config = if config.use_baseline && Option.is_none baseline_network then invalid_arg "Reinforce.create: baseline_network required when use_baseline = true"; let keys = Rune.Rng.split ~n:2 rng in let policy_params = init policy_network ~rngs:keys.(0) ~dtype:Rune.float32 in let policy_optimizer = Optimizer.adam ~lr:config.learning_rate () in let policy_opt_state = policy_optimizer.init policy_params in let ( baseline_network_state, baseline_params, baseline_optimizer, baseline_opt_state ) = match baseline_network with | Some net when config.use_baseline -> let params = init net ~rngs:keys.(1) ~dtype:Rune.float32 in let opt = Optimizer.adam ~lr:config.learning_rate () in let opt_state = opt.init params in (Some net, Some params, Some opt, Some opt_state) | _ -> (None, None, None, None) in { policy_network; policy_params; baseline_network = baseline_network_state; baseline_params; policy_optimizer; policy_opt_state; baseline_optimizer; baseline_opt_state; rng = keys.(0); n_actions; config; } let predict t obs ~training = (* Add batch dimension if needed: [features] -> [1, features] *) let obs_shape = Rune.shape obs in let obs_batched = if Array.length obs_shape = 1 then (* No batch dimension, add it *) let features = obs_shape.(0) in Rune.reshape [| 1; features |] obs else (* Already has batch dimension *) obs in let logits = apply t.policy_network t.policy_params ~training obs_batched in if training then ( (* Sample from policy distribution using Rune RNG *) let probs = Rune.softmax logits ~axes:[ -1 ] in let probs_array = Rune.to_array probs in (* Split RNG and sample *) let keys = Rune.Rng.split t.rng ~n:2 in t.rng <- keys.(0); let sample_rng = keys.(1) in let uniform_sample = Rune.Rng.uniform sample_rng Rune.float32 [| 1 |] in let r = (Rune.to_array uniform_sample).(0) in let rec sample_idx i cumsum = if i >= Array.length probs_array - 1 then i else if r <= cumsum +. probs_array.(i) then i else sample_idx (i + 1) (cumsum +. probs_array.(i)) in let action_idx = sample_idx 0 0.0 in let action = Rune.scalar Rune.int32 (Int32.of_int action_idx) in (* Compute log_softmax manually *) let max_logits = Rune.max logits ~axes:[ -1 ] ~keepdims:true in let exp_logits = Rune.exp (Rune.sub logits max_logits) in let sum_exp = Rune.sum exp_logits ~axes:[ -1 ] ~keepdims:true in let log_probs = Rune.sub logits (Rune.add max_logits (Rune.log sum_exp)) in let log_prob_array = Rune.to_array log_probs in let log_prob = log_prob_array.(action_idx) in (action, log_prob)) else (* Greedy selection *) let action_idx = Rune.argmax logits ~axis:(-1) ~keepdims:false in let action = Rune.cast Rune.int32 action_idx in (action, 0.0) let predict_value t obs = match (t.baseline_network, t.baseline_params) with | Some net, Some params -> (* Add batch dimension if needed *) let obs_shape = Rune.shape obs in let obs_batched = if Array.length obs_shape = 1 then let features = obs_shape.(0) in Rune.reshape [| 1; features |] obs else obs in let value = apply net params ~training:false obs_batched in let value_array = Rune.to_array value in value_array.(0) | _ -> 0.0 let update t trajectory = let open Fehu in (* Compute returns using Monte Carlo *) let compute_returns ~rewards ~dones ~gamma = let n = Array.length rewards in let returns = Array.make n 0.0 in let running_return = ref 0.0 in for i = n - 1 downto 0 do if dones.(i) then running_return := 0.0; running_return := rewards.(i) +. (gamma *. !running_return); returns.(i) <- !running_return done; returns in let returns_raw = compute_returns ~rewards:trajectory.Trajectory.rewards ~dones:trajectory.Trajectory.terminateds ~gamma:t.config.gamma in let returns = Array.map (fun r -> r *. t.config.reward_scale) returns_raw in (* Compute advantages *) let advantages = if t.config.use_baseline then match trajectory.Trajectory.values with | Some values -> Array.mapi (fun i r -> r -. values.(i)) returns | None -> returns else returns in let n_steps = Array.length advantages in let steps_f = float_of_int n_steps in (* Compute advantage statistics *) let adv_mean = if n_steps = 0 then 0.0 else Array.fold_left ( +. ) 0.0 advantages /. steps_f in let adv_var = if n_steps = 0 then 0.0 else Array.fold_left (fun acc a -> let diff = a -. adv_mean in acc +. (diff *. diff)) 0.0 advantages /. steps_f in let adv_std = sqrt adv_var in (* Normalize advantages *) let advantages_norm = if n_steps = 0 then [||] else let denom = if adv_std < 1e-6 then 1.0 else adv_std in Array.map (fun a -> (a -. adv_mean) /. denom) advantages in let entropy_acc = ref 0.0 in let log_prob_acc = ref 0.0 in (* Policy gradient loss *) let policy_loss_grad params = let total_loss = ref (Rune.scalar Rune.float32 0.0) in for i = 0 to Array.length trajectory.Trajectory.observations - 1 do let obs = trajectory.Trajectory.observations.(i) in let action = trajectory.Trajectory.actions.(i) in let advantage = advantages_norm.(i) in (* Add batch dimension if needed: [features] -> [1, features] *) let obs_shape = Rune.shape obs in let obs_batched = if Array.length obs_shape = 1 then let features = obs_shape.(0) in Rune.reshape [| 1; features |] obs else obs in let logits = apply t.policy_network params ~training:true obs_batched in (* Compute log_softmax *) let max_logits = Rune.max logits ~axes:[ -1 ] ~keepdims:true in let exp_logits = Rune.exp (Rune.sub logits max_logits) in let sum_exp = Rune.sum exp_logits ~axes:[ -1 ] ~keepdims:true in let log_probs_pred = Rune.sub logits (Rune.add max_logits (Rune.log sum_exp)) in (* Compute entropy for logging *) let probs = Rune.softmax logits ~axes:[ -1 ] in let probs_arr = Rune.to_array (Rune.reshape [| t.n_actions |] probs) in let entropy = Array.fold_left (fun acc p -> if p > 0. then acc -. (p *. log p) else acc) 0.0 probs_arr in entropy_acc := !entropy_acc +. entropy; (* Select log_prob for taken action *) let log_probs_flat = Rune.reshape [| t.n_actions |] log_probs_pred in let indices = Rune.reshape [| 1 |] (Rune.cast Rune.int32 action) in let selected = Rune.take indices log_probs_flat |> Rune.reshape [||] in let log_prob_float = (Rune.to_array selected).(0) in log_prob_acc := !log_prob_acc +. log_prob_float; (* Policy gradient: -log π(a|s) × advantage *) let loss = Rune.mul_s (Rune.neg selected) advantage in total_loss := Rune.add !total_loss loss; (* Entropy regularization *) if t.config.entropy_coef <> 0. then let entropy_scalar = Rune.sum (Rune.mul probs log_probs_pred) ~axes:[ -1 ] ~keepdims:false |> Rune.reshape [||] in let entropy_term = Rune.mul_s entropy_scalar (-.t.config.entropy_coef) in total_loss := Rune.add !total_loss entropy_term done; let avg_loss = Rune.div_s !total_loss (float_of_int (Array.length trajectory.Trajectory.observations)) in avg_loss in (* Update policy *) let _policy_loss, policy_grads = value_and_grad policy_loss_grad t.policy_params in let policy_updates, new_policy_opt_state = t.policy_optimizer.update t.policy_opt_state t.policy_params policy_grads in t.policy_params <- Optimizer.apply_updates t.policy_params policy_updates; t.policy_opt_state <- new_policy_opt_state; (* Update baseline if present *) let value_loss_acc = ref 0.0 in (if t.config.use_baseline then match ( t.baseline_network, t.baseline_params, t.baseline_optimizer, t.baseline_opt_state ) with | Some net, Some params, Some opt, Some opt_state -> let baseline_loss_grad params = let total_loss = ref (Rune.scalar Rune.float32 0.0) in for i = 0 to Array.length trajectory.Trajectory.observations - 1 do let obs = trajectory.Trajectory.observations.(i) in let return = returns.(i) in (* Add batch dimension if needed: [features] -> [1, features] *) let obs_shape = Rune.shape obs in let obs_batched = if Array.length obs_shape = 1 then let features = obs_shape.(0) in Rune.reshape [| 1; features |] obs else obs in let value_pred = apply net params ~training:true obs_batched in let value = Rune.reshape [||] value_pred in let value_float = (Rune.to_array value).(0) in let diff = value_float -. return in value_loss_acc := !value_loss_acc +. (diff *. diff); let target = Rune.scalar Rune.float32 return in let loss = Rune.square (Rune.sub value target) in total_loss := Rune.add !total_loss loss done; let avg_loss = Rune.div_s !total_loss (float_of_int (Array.length trajectory.Trajectory.observations)) in avg_loss in let _, baseline_grads = value_and_grad baseline_loss_grad params in let baseline_updates, new_baseline_opt_state = opt.update opt_state params baseline_grads in t.baseline_params <- Some (Optimizer.apply_updates params baseline_updates); t.baseline_opt_state <- Some new_baseline_opt_state | _ -> ()); (* Compute metrics *) let episode_return = if n_steps > 0 then returns_raw.(0) else 0.0 in let avg_entropy = if n_steps = 0 then 0.0 else !entropy_acc /. steps_f in let avg_log_prob = if n_steps = 0 then 0.0 else !log_prob_acc /. steps_f in let value_loss_avg = if n_steps = 0 then None else if t.config.use_baseline then Some (!value_loss_acc /. steps_f) else None in let metrics = { episode_return; episode_length = n_steps; avg_entropy; avg_log_prob; adv_mean; adv_std; value_loss = value_loss_avg; } in (t, metrics) let learn t ~env ~total_timesteps ?(callback = fun ~iteration:_ ~metrics:_ -> true) () = let open Fehu in let timesteps = ref 0 in let iteration = ref 0 in while !timesteps < total_timesteps do (* Collect episode *) let observations = ref [] in let actions = ref [] in let rewards = ref [] in let terminateds = ref [] in let truncateds = ref [] in let log_probs = ref [] in let values = ref [] in let obs, _info = Env.reset env () in let current_obs = ref obs in let steps = ref 0 in let done_flag = ref false in while !steps < t.config.max_episode_steps && not !done_flag do let action, log_prob = predict t !current_obs ~training:true in let value = if t.config.use_baseline then predict_value t !current_obs else 0.0 in observations := !current_obs :: !observations; actions := action :: !actions; log_probs := log_prob :: !log_probs; values := value :: !values; let transition = Env.step env action in rewards := transition.Env.reward :: !rewards; terminateds := transition.Env.terminated :: !terminateds; truncateds := transition.Env.truncated :: !truncateds; current_obs := transition.Env.observation; steps := !steps + 1; done_flag := transition.Env.terminated || transition.Env.truncated done; timesteps := !timesteps + !steps; (* Create trajectory *) let log_probs_array = Array.of_list (List.rev !log_probs) in let values_array = Array.of_list (List.rev !values) in let trajectory = Trajectory.create ~observations:(Array.of_list (List.rev !observations)) ~actions:(Array.of_list (List.rev !actions)) ~rewards:(Array.of_list (List.rev !rewards)) ~terminateds:(Array.of_list (List.rev !terminateds)) ~truncateds:(Array.of_list (List.rev !truncateds)) ~log_probs:log_probs_array ?values:(if t.config.use_baseline then Some values_array else None) () in (* Update *) let _t, metrics = update t trajectory in iteration := !iteration + 1; (* Call callback *) let continue = callback ~iteration:!iteration ~metrics in if not continue then timesteps := total_timesteps done; t let save _t _path = (* TODO: Implement serialization *) failwith "Reinforce.save: not yet implemented" let load _path = (* TODO: Implement deserialization *) failwith "Reinforce.load: not yet implemented"
sectionYPositions = computeSectionYPositions($el), 10)"
x-init="setTimeout(() => sectionYPositions = computeSectionYPositions($el), 10)"
>