DQN (Deep Q-Network) algorithm.
DQN is an off-policy value-based algorithm that learns Q-values (expected returns) for state-action pairs using neural network function approximation. It combines Q-learning with experience replay and target networks for stable training.
Algorithm
DQN follows these steps:
1. Collect transitions (s, a, r, s') using epsilon-greedy exploration
2. Store transitions in an experience replay buffer
3. Sample random minibatches from the buffer
4. Update Q-network using TD targets from a frozen target network
5. Periodically update target network by copying Q-network parameters
The algorithm minimizes the temporal difference error:
L = E[(Q(s,a) - (r + γ max_a' Q_target(s', a')))²]
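To make this loss concrete, the following self-contained sketch in plain OCaml evaluates the squared TD error for a single transition. It only mirrors the formula above; the function name and the numbers are made up for illustration and are not part of the library.
(* Squared TD error for one transition, mirroring the loss above. *)
let squared_td_error ~q_sa ~reward ~gamma ~max_next_q =
  let td_target = reward +. (gamma *. max_next_q) in
  let td_error = q_sa -. td_target in
  td_error *. td_error

let () =
  (* Placeholder values: Q(s,a) = 1.2, r = 1.0, gamma = 0.99, max_a' Q_target = 0.9 *)
  let err = squared_td_error ~q_sa:1.2 ~reward:1.0 ~gamma:0.99 ~max_next_q:0.9 in
  Printf.printf "squared TD error = %.4f\n" err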
Usage
Basic usage:
open Fehu
(* Create Q-network *)
let q_net = Kaun.Layer.sequential [
  Kaun.Layer.linear ~in_features:4 ~out_features:64 ();
  Kaun.Layer.relu ();
  Kaun.Layer.linear ~in_features:64 ~out_features:2 ();
] in
(* Initialize agent *)
let agent = Dqn.create
  ~q_network:q_net
  ~n_actions:2
  ~rng:(Rune.Rng.key 42)
  Dqn.{ default_config with batch_size = 64 }
in
(* Train *)
let agent = Dqn.learn agent ~env ~total_timesteps:100_000 () in
(* Use trained policy (greedy) *)
let action = Dqn.predict agent obs ~epsilon:0.0
Manual training loop:
(* Collect transition *)
let obs, _info = Env.reset env () in
let action = Dqn.predict agent obs ~epsilon:0.1 in
let transition = Env.step env action in
(* Store in buffer *)
Dqn.add_transition agent ~observation:obs ~action
  ~reward:transition.reward ~next_observation:transition.observation
  ~terminated:transition.terminated ~truncated:transition.truncated;
(* Update Q-network *)
let loss, avg_q = Dqn.update agent in
(* Periodically update target network *)
if episode mod 10 = 0 then Dqn.update_target_network agent
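Building on the same calls (Env.reset, Env.step, Dqn.predict) and the transition fields used above, a greedy evaluation rollout could look like the sketch below. This is an illustrative fragment, not a function provided by the library, and it assumes the reward field is a float.
(* Roll out one episode with the greedy policy and return its total reward. *)
let evaluate agent env =
  let obs, _info = Env.reset env () in
  let rec loop obs total_return =
    let action = Dqn.predict agent obs ~epsilon:0.0 in
    let transition = Env.step env action in
    let total_return = total_return +. transition.reward in
    if transition.terminated || transition.truncated then total_return
    else loop transition.observation total_return
  in
  loop obs 0.0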
Key Features
Experience Replay: Breaks correlation between consecutive samples by randomly sampling from a replay buffer of past transitions (a minimal buffer sketch follows this list).
Target Network: Uses a separate, periodically-updated network for computing TD targets, improving stability.
Epsilon-Greedy Exploration: Balances exploration and exploitation with a decaying epsilon parameter.
Off-Policy: Can learn from any transitions, enabling experience reuse and data-efficient learning.
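The replay buffer mentioned under Experience Replay can be pictured as a small circular array. The sketch below is purely pedagogical plain OCaml, not Fehu's internal buffer; the type and helper functions are made up for illustration.
(* A minimal circular replay buffer over an arbitrary transition type 'a. *)
type 'a replay_buffer = {
  data : 'a option array;
  mutable next : int;   (* index where the next transition is written *)
  mutable size : int;   (* number of transitions currently stored *)
}

let create_buffer capacity =
  { data = Array.make capacity None; next = 0; size = 0 }

let add buf x =
  (* Once full, the oldest transition is overwritten. *)
  buf.data.(buf.next) <- Some x;
  buf.next <- (buf.next + 1) mod Array.length buf.data;
  buf.size <- min (buf.size + 1) (Array.length buf.data)

let sample buf batch_size =
  (* Uniform sampling with replacement; call only after at least one add. *)
  List.init batch_size (fun _ ->
      match buf.data.(Random.int buf.size) with
      | Some x -> x
      | None -> assert false)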
When to Use DQN
Discrete action spaces (e.g., game controls, navigation)
Environments where off-policy learning is beneficial
Tasks requiring sample efficiency through experience replay
Problems with deterministic or near-deterministic dynamics
For continuous action spaces, consider SAC or DDPG (coming soon).
An agent encapsulates the Q-network, target network, experience replay buffer, optimizer state, and training configuration, and maintains all state needed for training and inference.
Note: Observations must be float32 tensors and actions are int32 tensors (discrete actions).
The Q-network should be a standard feedforward network. For image observations, use convolutional layers. For vector observations, use fully-connected layers.
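As an illustration, a slightly deeper fully-connected Q-network for a larger vector observation space can be built from the same layer constructors used in the Usage section. The sizes below (8 inputs, 128 hidden units, 4 actions) are placeholders and must match your environment; the constructor signatures are assumed to be exactly as shown above.
(* A deeper MLP Q-network for vector observations. *)
let q_net =
  Kaun.Layer.sequential [
    Kaun.Layer.linear ~in_features:8 ~out_features:128 ();
    Kaun.Layer.relu ();
    Kaun.Layer.linear ~in_features:128 ~out_features:128 ();
    Kaun.Layer.relu ();
    Kaun.Layer.linear ~in_features:128 ~out_features:4 ();
  ]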
predict agent obs ~epsilon selects an action using epsilon-greedy policy.
With probability epsilon, selects a random action (exploration). With probability 1 - epsilon, selects the action with highest Q-value (exploitation).
Parameters:
agent: DQN agent.
obs: Observation tensor of shape [obs_dim] or [batch_size; obs_dim]; batching is handled automatically when needed.
epsilon: Exploration rate in [0, 1]. Use 0.0 for fully greedy policy (no exploration), 1.0 for fully random policy.
Returns action as int32 scalar tensor.
Example:
(* During training with decaying exploration *)
let epsilon = compute_epsilon ~timesteps in
let action = Dqn.predict agent obs ~epsilon in
(* During evaluation (no exploration) *)
let action = Dqn.predict agent obs ~epsilon:0.0
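The compute_epsilon helper in the example above is not part of the library. One common choice, sketched below with arbitrary schedule constants, is a linear decay from 1.0 to a small floor over a fixed number of steps.
(* Linear epsilon decay: 1.0 at step 0, 0.05 after 10_000 steps, then flat.
   The constants are placeholders; tune them per environment. *)
let compute_epsilon ~timesteps =
  let eps_start = 1.0 and eps_end = 0.05 and decay_steps = 10_000 in
  let frac = Float.min 1.0 (float_of_int timesteps /. float_of_int decay_steps) in
  eps_start +. (frac *. (eps_end -. eps_start))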
add_transition agent ~observation ~action ~reward ~next_observation ~terminated ~truncated stores a single transition in the replay buffer.
terminated: Whether the episode ended in a terminal state.
truncated: Whether the episode was artificially truncated (timeout).
Transitions are stored in a circular buffer. When the buffer is full, oldest transitions are overwritten.
The distinction between terminated and truncated matters for bootstrapping: terminal states have value 0, while truncated states may have non-zero value.
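A minimal plain-OCaml sketch of that rule (independent of the library's actual implementation) makes the difference explicit: only termination cuts off bootstrapping.
(* TD target for one transition: bootstrap from the next state unless the
   episode truly terminated. Truncation alone does not zero the tail value. *)
let td_target ~reward ~gamma ~terminated ~max_next_q =
  if terminated then reward
  else reward +. (gamma *. max_next_q)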
update agent performs a single gradient update on the Q-network.
Samples a batch from the replay buffer, computes TD targets using the target network, and updates the Q-network parameters using gradient descent on the TD error.
Returns (loss, avg_q_value) where:
loss: Mean squared TD error for the batch
avg_q_value: Average Q-value predicted for the batch
If the replay buffer has fewer samples than batch_size, returns (0.0, 0.0) without performing an update.
The TD target is computed as:
y = r + γ max_a' Q_target(s', a')
where Q_target is the frozen target network.
Call this function after each environment step during training.
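A typical call site simply destructures the result and logs it; this fragment only reuses the call documented here.
(* One gradient step per environment step. A result of (0.0, 0.0) just means the
   replay buffer holds fewer than batch_size transitions, so nothing was updated. *)
let loss, avg_q = Dqn.update agent in
Printf.printf "TD loss = %.4f, mean Q = %.3f\n" loss avg_q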
update_target_network agent updates the target network by copying Q-network parameters.
Should be called periodically (every config.target_update_freq episodes) to keep the target network stable. Frequent updates can cause instability, while infrequent updates can slow learning.
Example:
if episode mod agent.config.target_update_freq = 0 then
  Dqn.update_target_network agent
learn agent ~env ~total_timesteps ~callback ~warmup_steps () trains the DQN agent on an environment.
Runs episodes until total_timesteps is reached. Each episode:
1. Resets environment
2. Collects transitions using epsilon-greedy policy
3. Stores transitions in replay buffer
4. Samples batches and updates Q-network
5. Periodically updates target network
Parameters:
agent: DQN agent to train.
env: Environment to train on.
total_timesteps: Total number of environment steps to train for.
callback: Optional callback called after each episode with the episode number and metrics. Return false to stop training early. The default callback always returns true.
warmup_steps: Number of initial steps to collect before starting training (filling replay buffer). Default: batch_size.
Returns the trained agent.
Example with callback:
let agent =
  Dqn.learn agent ~env ~total_timesteps:100_000
    ~callback:(fun ~episode ~metrics ->
      if episode mod 10 = 0 then
        Printf.printf "Episode %d: Return = %.2f, Loss = %.4f\n" episode
          metrics.episode_return metrics.loss;
      true (* continue training *))
    ()
The warmup phase collects random experiences before training begins. This ensures the replay buffer has diverse samples before Q-network updates start.
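For example, a longer warmup than the default batch_size can be requested explicitly; the 1_000 steps below are an arbitrary illustration.
(* Fill the replay buffer with 1_000 transitions before the first update. *)
let agent =
  Dqn.learn agent ~env ~total_timesteps:100_000 ~warmup_steps:1_000 ()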