package fehu
Module Fehu_algorithms.Dqn
DQN algorithm implementation.

DQN (Deep Q-Network) is an off-policy value-based method that uses experience replay and target networks for stable training. It learns Q-values for discrete actions and selects actions greedily. See Dqn for detailed documentation.
DQN (Deep Q-Network) algorithm.
DQN is an off-policy value-based algorithm that learns Q-values (expected returns) for state-action pairs using neural network function approximation. It combines Q-learning with experience replay and target networks for stable training.
Algorithm
DQN follows these steps:
- Collect transitions (s, a, r, s') using epsilon-greedy exploration
- Store transitions in an experience replay buffer
- Sample random minibatches from the buffer
- Update Q-network using TD targets from a frozen target network
- Periodically update target network by copying Q-network parameters
The algorithm minimizes the temporal difference error:
L = E[(Q(s,a) - (r + γ max_a' Q_target(s', a')))²]
Usage
Basic usage:
open Fehu

(* Create Q-network *)
let q_net = Kaun.Layer.sequential [
    Kaun.Layer.linear ~in_features:4 ~out_features:64 ();
    Kaun.Layer.relu ();
    Kaun.Layer.linear ~in_features:64 ~out_features:2 ();
  ] in

(* Initialize agent *)
let agent =
  Dqn.create
    ~q_network:q_net
    ~n_actions:2
    ~rng:(Rune.Rng.key 42)
    Dqn.{ default_config with batch_size = 64 }
in

(* Train *)
let agent = Dqn.learn agent ~env ~total_timesteps:100_000 () in

(* Use trained policy (greedy) *)
let action = Dqn.predict agent obs ~epsilon:0.0
Manual training loop:
(* Collect transition *)
let obs, _info = Env.reset env () in
let action = Dqn.predict agent obs ~epsilon:0.1 in
let transition = Env.step env action in
(* Store in buffer *)
Dqn.add_transition agent ~observation:obs ~action
  ~reward:transition.reward ~next_observation:transition.observation
  ~terminated:transition.terminated ~truncated:transition.truncated;
(* Update Q-network *)
let loss, avg_q = Dqn.update agent in
(* Periodically update target network *)
if episode mod 10 = 0 then Dqn.update_target_network agent
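Putting these pieces together, a full manual episode loop could take the following shape. This is only a sketch: run_episode and its mutable bookkeeping are not part of the API; it simply reuses the calls shown above and assumes the same open Fehu context.

let run_episode agent env ~epsilon =
  (* Reset and track the current observation and episode return. *)
  let obs = ref (fst (Env.reset env ())) in
  let finished = ref false in
  let total_reward = ref 0.0 in
  while not !finished do
    (* Epsilon-greedy action, environment step, store, then one update. *)
    let action = Dqn.predict agent !obs ~epsilon in
    let transition = Env.step env action in
    Dqn.add_transition agent ~observation:!obs ~action
      ~reward:transition.reward ~next_observation:transition.observation
      ~terminated:transition.terminated ~truncated:transition.truncated;
    let _loss, _avg_q = Dqn.update agent in
    total_reward := !total_reward +. transition.reward;
    obs := transition.observation;
    finished := transition.terminated || transition.truncated
  done;
  !total_reward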
Key Features
- Experience Replay: Breaks correlation between consecutive samples by randomly sampling from a replay buffer of past transitions.
- Target Network: Uses a separate, periodically-updated network for computing TD targets, improving stability.
- Epsilon-Greedy Exploration: Balances exploration and exploitation with decaying epsilon parameter.
- Off-Policy: Can learn from any transitions, enabling experience reuse and data-efficient learning.
When to Use DQN
- Discrete action spaces (e.g., game controls, navigation)
- Environments where off-policy learning is beneficial
- Tasks requiring sample efficiency through experience replay
- Problems with deterministic or near-deterministic dynamics
For continuous action spaces, consider SAC or DDPG (coming soon).
type config = {
  learning_rate : float;
  (* Learning rate for Q-network optimizer (default: 0.001) *)
  gamma : float;
  (* Discount factor for future rewards (default: 0.99) *)
  epsilon_start : float;
  (* Initial exploration rate (default: 1.0 = fully random) *)
  epsilon_end : float;
  (* Final exploration rate (default: 0.01 = mostly greedy) *)
  epsilon_decay : float;
  (* Decay rate for epsilon. Epsilon decays as:
     epsilon_end + (epsilon_start - epsilon_end) * exp(-timesteps / epsilon_decay)
     (default: 1000.0) *)
  batch_size : int;
  (* Number of transitions to sample per update (default: 32) *)
  buffer_capacity : int;
  (* Maximum size of replay buffer (default: 10_000) *)
  target_update_freq : int;
  (* Update target network every N episodes (default: 10) *)
}
Configuration for DQN training.
Typical values:
- learning_rate: 0.0001-0.001 for stable convergence
- gamma: 0.95-0.99 depending on episode length
- epsilon_start: 1.0 to ensure initial exploration
- epsilon_end: 0.01-0.1 for some exploration after convergence
- epsilon_decay: 500-2000 for a gradual exploration-exploitation transition
- batch_size: 32-256 depending on memory and compute
- buffer_capacity: 10_000-1_000_000 depending on problem complexity
- target_update_freq: 5-100 episodes for stability
Default configuration:
learning_rate = 0.001
gamma = 0.99
epsilon_start = 1.0
epsilon_end = 0.01
epsilon_decay = 1000.0
batch_size = 32
buffer_capacity = 10_000
target_update_freq = 10
These defaults work reasonably for simple environments like CartPole or GridWorld. Adjust for your specific problem.
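For instance, a run on a harder problem might override a few fields of default_config (the values here are illustrative, not recommendations):

let config =
  Dqn.{ default_config with
        learning_rate = 0.0003;
        buffer_capacity = 100_000;
        target_update_freq = 50 }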
type t

DQN agent state.
Encapsulates Q-network, target network, experience replay buffer, optimizer state, and training configuration. The agent maintains all state needed for training and inference.
Note: Observations must be float32 tensors and actions are int32 tensors (discrete actions).
type update_metrics = {
  episode_return : float;
  (* Total reward for the episode *)
  episode_length : int;
  (* Number of steps in the episode *)
  epsilon : float;
  (* Current exploration rate *)
  avg_q_value : float;
  (* Average Q-value from the batch *)
  loss : float;
  (* TD loss from the batch update *)
}
Metrics from a training episode.
These metrics help monitor training progress:
- episode_return should increase over time as the policy improves
- episode_length may increase or decrease depending on the task
- epsilon decays from epsilon_start to epsilon_end
- avg_q_value should stabilize as Q-values converge
- loss should decrease and stabilize (though it may fluctuate)
create ~q_network ~n_actions ~rng config creates a DQN agent.
Parameters:
- q_network: Kaun network that takes observations and outputs Q-values for each action. Should output shape (batch_size, n_actions).
- n_actions: Number of discrete actions in the environment.
- rng: Random number generator key for initialization and exploration.
- config: DQN configuration (learning rate, gamma, epsilon, etc.).
The Q-network should be a standard feedforward network. For image observations, use convolutional layers. For vector observations, use fully-connected layers.
Example Q-network architectures:
(* For vector observations *)
let q_net =
  Kaun.Layer.sequential
    [
      Kaun.Layer.linear ~in_features:4 ~out_features:128 ();
      Kaun.Layer.relu ();
      Kaun.Layer.linear ~in_features:128 ~out_features:128 ();
      Kaun.Layer.relu ();
      Kaun.Layer.linear ~in_features:128 ~out_features:n_actions ();
    ]

(* For image observations *)
let q_net =
  Kaun.Layer.sequential
    [
      Kaun.Layer.conv2d ~in_channels:4 ~out_channels:32 ~kernel_size:8
        ~stride:4 ();
      Kaun.Layer.relu ();
      Kaun.Layer.conv2d ~in_channels:32 ~out_channels:64 ~kernel_size:4
        ~stride:2 ();
      Kaun.Layer.relu ();
      Kaun.Layer.flatten ();
      Kaun.Layer.linear ~in_features:3136 ~out_features:512 ();
      Kaun.Layer.relu ();
      Kaun.Layer.linear ~in_features:512 ~out_features:n_actions ();
    ]
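Note that in_features of the first linear layer after Kaun.Layer.flatten must equal out_channels * height * width of the final convolution output, which depends on the input resolution. A small helper for checking this (illustration only; it assumes unpadded convolutions):

(* Side length of a square, unpadded convolution output. *)
let conv_out ~size ~kernel ~stride = ((size - kernel) / stride) + 1

(* Flattened feature count after the two conv layers above, for a square
   input of side [size] with 64 output channels. *)
let flat_features ~size =
  let h1 = conv_out ~size ~kernel:8 ~stride:4 in
  let h2 = conv_out ~size:h1 ~kernel:4 ~stride:2 in
  64 * h2 * h2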
val predict :
  t ->
  (float, Bigarray.float32_elt) Rune.t ->
  epsilon:float ->
  (int32, Bigarray.int32_elt) Rune.t
predict agent obs ~epsilon selects an action using an epsilon-greedy policy.
With probability epsilon, selects a random action (exploration). With probability 1 - epsilon, selects the action with the highest Q-value (exploitation).
Parameters:
- agent: DQN agent.
- obs: Observation tensor of shape (obs_dim) or (batch_size, obs_dim). Automatically handles batching if needed.
- epsilon: Exploration rate in [0, 1]. Use 0.0 for a fully greedy policy (no exploration), 1.0 for a fully random policy.
Returns action as int32 scalar tensor.
Example:
(* During training with decaying exploration *)
let epsilon = compute_epsilon ~timesteps in
let action = Dqn.predict agent obs ~epsilon in
(* During evaluation (no exploration) *)
let action = Dqn.predict agent obs ~epsilon:0.0
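The compute_epsilon helper above is not part of the Dqn API; a minimal sketch of it, following the decay formula documented in config, could look like this:

(* [timesteps] is the total number of environment steps taken so far. *)
let compute_epsilon ~timesteps =
  let open Dqn in
  let c = default_config in
  c.epsilon_end
  +. (c.epsilon_start -. c.epsilon_end)
     *. exp (-. (float_of_int timesteps /. c.epsilon_decay))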
val add_transition :
  t ->
  observation:(float, Bigarray.float32_elt) Rune.t ->
  action:(int32, Bigarray.int32_elt) Rune.t ->
  reward:float ->
  next_observation:(float, Bigarray.float32_elt) Rune.t ->
  terminated:bool ->
  truncated:bool ->
  unit
add_transition agent ~observation ~action ~reward ~next_observation ~terminated ~truncated stores a transition in the replay buffer.
Parameters:
- agent: DQN agent.
- observation: Current state observation (without batch dimension).
- action: Action taken (int32 scalar).
- reward: Immediate reward received.
- next_observation: Resulting next state (without batch dimension).
- terminated: Whether the episode ended naturally (reached a goal or failure state).
- truncated: Whether the episode was artificially truncated (timeout).
Transitions are stored in a circular buffer. When the buffer is full, oldest transitions are overwritten.
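As an illustration of that overwrite behaviour (the real buffer is internal to the agent, so this is purely a sketch), the write position simply wraps around once capacity is reached:

(* Hypothetical array-backed ring buffer: the slot for the next write is the
   number of transitions added so far, modulo capacity, so the oldest entry
   is reused once the buffer is full. *)
let write slots ~capacity ~n_added transition =
  slots.(n_added mod capacity) <- transition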
The distinction between terminated and truncated matters for bootstrapping: terminal states have value 0, while truncated states may have non-zero value.
update agent performs a single gradient update on the Q-network.
Samples a batch from the replay buffer, computes TD targets using the target network, and updates the Q-network parameters using gradient descent on the TD error.
Returns (loss, avg_q_value) where:
- loss: Mean squared TD error for the batch
- avg_q_value: Average Q-value predicted for the batch

If the replay buffer has fewer samples than batch_size, returns (0.0, 0.0) without performing an update.
The TD target is computed as:
y = r + γ max_a' Q_target(s', a')
where Q_target is the frozen target network.
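For intuition, the per-transition scalar form of this target, with bootstrapping skipped on termination, looks as follows (update itself works on batched tensors, so this is illustration only):

(* [max_next_q] stands for max_a' Q_target(s', a') from the formula above. *)
let td_target ~reward ~gamma ~terminated ~max_next_q =
  if terminated then reward else reward +. (gamma *. max_next_q)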
Call this function after each environment step during training.
update_target_network agent updates the target network by copying Q-network parameters.
Should be called periodically (every config.target_update_freq episodes) to keep the target network stable. Frequent updates can cause instability, while infrequent updates can slow learning.
Example:
if episode mod agent.config.target_update_freq = 0 then
  Dqn.update_target_network agent
val learn :
  t ->
  env:
    ((float, Bigarray.float32_elt) Rune.t,
     (int32, Bigarray.int32_elt) Rune.t,
     'render)
    Fehu.Env.t ->
  total_timesteps:int ->
  ?callback:(episode:int -> metrics:update_metrics -> bool) ->
  ?warmup_steps:int ->
  unit ->
  t
learn agent ~env ~total_timesteps ~callback ~warmup_steps () trains the DQN agent on an environment.
Runs episodes until total_timesteps is reached. Each episode:
1. Resets the environment
2. Collects transitions using the epsilon-greedy policy
3. Stores transitions in the replay buffer
4. Samples batches and updates the Q-network
5. Periodically updates the target network
Parameters:
- agent: DQN agent to train.
- env: Environment to train on.
- total_timesteps: Total number of environment steps to train for.
- callback: Optional callback called after each episode with the episode number and metrics. Return false to stop training early. Default always returns true.
- warmup_steps: Number of initial steps to collect before training starts (to fill the replay buffer). Default: batch_size.
Returns the trained agent.
Example with callback:
let agent =
  Dqn.learn agent ~env ~total_timesteps:100_000
    ~callback:(fun ~episode ~metrics ->
      if episode mod 10 = 0 then
        Printf.printf "Episode %d: Return = %.2f, Loss = %.4f\n" episode
          metrics.episode_return metrics.loss;
      true (* continue training *))
    ()
The warmup phase collects random experiences before training begins. This ensures the replay buffer has diverse samples before Q-network updates start.
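For example, to collect 1,000 steps before the first gradient update (using the optional warmup_steps argument from the signature above):

let agent =
  Dqn.learn agent ~env ~total_timesteps:100_000 ~warmup_steps:1_000 ()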
save agent path saves the agent to disk.
Not yet implemented. Will serialize Q-network parameters, target network, optimizer state, and configuration.