Deep Q-Network (DQN)
At its core, DQN extends Q-learning by replacing the Q-table with a neural network, which lets it handle high-dimensional and continuous state spaces such as the raw images of video games [Playing Atari with Deep Reinforcement Learning]. In brief, DQN is still Q-learning; the difference is that the Q-function is represented by the parameters of a neural network rather than by a lookup table.
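For comparison, here is a minimal numpy sketch of the tabular Q-learning update that DQN generalizes. The function name `q_learning_update`, the learning rate `alpha`, and the toy table are illustrative assumptions, not code from the original post. DQN keeps the same target, $r + γ\max_{a'}Q(s',a')$, but replaces the table lookup and the in-place update with a neural network trained by gradient descent.

```python
import numpy as np

def q_learning_update(Q, s, a, r, s_next, done, alpha=0.1, gamma=0.99):
    """One tabular Q-learning step on an (n_states, n_actions) table, updated in place."""
    # Bootstrap from the best next action unless the episode has ended
    target = r if done else r + gamma * np.max(Q[s_next])
    # Move Q(s, a) a fraction alpha of the way toward the target
    Q[s, a] += alpha * (target - Q[s, a])
    return Q

# Toy table: 3 states x 2 actions, all zeros except the values of state 1
Q = np.zeros((3, 2))
Q[1] = [0.0, 2.0]
q_learning_update(Q, s=0, a=1, r=1.0, s_next=1, done=False)
print(Q[0, 1])  # ≈ 0.298, i.e. 0.1 * (1.0 + 0.99 * 2.0 - 0.0)
```

A table needs one entry per state-action pair, which is infeasible when each state is a stack of raw game frames; that is the gap the network fills.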
DQN Introduction
DQN approximates the optimal action-value function $Q^*(s,a)$ with a deep neural network $Q(s,a;θ)$ parameterized by θ. Because neural networks are non-linear function approximators, DQN can represent value functions that linear models cannot. A classic example is the XOR problem, whose outputs cannot be reproduced by any linear function of the inputs, while a network with a single hidden layer of non-linear units can represent it exactly. This ability to model non-linear relationships makes neural networks well suited to reinforcement learning tasks with high-dimensional inputs.
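To make the XOR point concrete, here is a small numpy check added for illustration (it is not part of the original code). The best least-squares linear fit predicts 0.5 for every input, while a two-unit ReLU network reproduces XOR exactly. The hidden-layer weights are hand-picked for readability, an assumption rather than the result of training.

```python
import numpy as np

X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=float)
y = np.array([0, 1, 1, 0], dtype=float)  # XOR truth table

# Best linear model y ≈ w1*x1 + w2*x2 + b, fitted by least squares
X_bias = np.hstack([X, np.ones((4, 1))])
w, *_ = np.linalg.lstsq(X_bias, y, rcond=None)
linear_pred = X_bias @ w

def relu(z):
    return np.maximum(z, 0.0)

def xor_net(x):
    # Two hidden ReLU units with hand-picked (not learned) weights
    h1 = relu(x[:, 0] + x[:, 1])        # number of inputs that are on
    h2 = relu(x[:, 0] + x[:, 1] - 1.0)  # non-zero only when both inputs are on
    return h1 - 2.0 * h2

print(np.round(linear_pred, 3))  # [0.5 0.5 0.5 0.5]
print(xor_net(X))                # [0. 1. 1. 0.]
```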
So, how does DQN work? Below, we outline the key components that make up the Deep Q-Network algorithm:
- Q-Function Approximator: the neural network takes a state as input and outputs one Q-value for each action, so a single forward pass scores every action.
- Loss Function: minimize the mean-squared error between the predicted Q-value and the target Q-value derived from the Bellman equation:
$$L(θ) = \mathbb{E}\left[\left(r + γ\max_{a'}Q(s',a';θ^{-}) - Q(s,a;θ)\right)^2\right]$$
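- Experience Replay: transitions $(s, a, r, s')$ are stored in a replay buffer and sampled uniformly at random in mini-batches to train the network. This breaks the correlation between consecutive samples and lets each transition be reused many times, improving sample efficiency.
- Target Network: a separate network with parameters $θ^{-}$ computes the target Q-values. Its parameters are copied from the online network only periodically, which keeps the targets from shifting at every update and stabilizes training.
- Exploration: a behavior policy selects actions during training, for example epsilon-greedy, upper confidence bound (UCB), or Thompson sampling. Epsilon-greedy is the usual choice: with probability ε the agent takes a random action, and otherwise it takes the action with the highest Q-value.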
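As a worked example of the loss, the numpy sketch below computes the TD targets and the mean-squared error for a made-up batch of three transitions. All Q-values, rewards, and `done` flags are invented for illustration; `q_online` plays the role of $Q(s,a;θ)$ and `q_target_next` that of $Q(s',a';θ^{-})$. The `(1 - dones)` factor is the usual convention that drops the bootstrap term when $s'$ is terminal, as the implementation below also does.

```python
import numpy as np

gamma = 0.99
# Made-up batch of 3 transitions in an environment with 2 actions
q_online = np.array([[1.0, 2.0],
                     [0.5, 0.0],
                     [3.0, 1.0]])        # Q(s, ·; θ) from the online network
actions = np.array([1, 0, 0])            # action taken in each transition
rewards = np.array([1.0, 0.0, -1.0])
dones = np.array([0.0, 0.0, 1.0])        # the third transition ends the episode
q_target_next = np.array([[2.0, 4.0],
                          [1.0, 0.0],
                          [5.0, 5.0]])   # Q(s', ·; θ⁻) from the target network

# y = r + γ max_a' Q(s', a'; θ⁻), with the bootstrap term masked out at terminal states
targets = rewards + (1.0 - dones) * gamma * q_target_next.max(axis=1)
# = [1 + 0.99*4, 0 + 0.99*1, -1] = [4.96, 0.99, -1.0]

# Q(s, a; θ) for the actions actually taken: [2.0, 0.5, 3.0]
predicted = q_online[np.arange(len(actions)), actions]

# Mean-squared TD error: (2.96**2 + 0.49**2 + 4.0**2) / 3
loss = np.mean((targets - predicted) ** 2)
print(f"{loss:.4f}")  # 8.3339
```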
There are many variations of the original DQN. Below, we list some of the most commonly used extensions:
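- Double DQN: reduces Q-value overestimation by decoupling action selection from action evaluation; the online network picks the best next action, and the target network estimates its value.
- Dueling DQN: splits the Q-value into a state value $V(s)$ and an action advantage $A(s,a)$, computed by separate streams of the network and then recombined.
- Noisy DQN: adds learnable noise to the network parameters so that exploration is learned by the network rather than driven by epsilon-greedy.
- Rainbow DQN: combines six extensions (double Q-learning, prioritized experience replay, dueling networks, multi-step returns, distributional RL, and noisy networks) into a single agent.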
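The first two extensions change only a line or two of the math, so a small numpy sketch with made-up numbers shows the difference. It is an illustration under simple assumptions (a batch of two next states, three actions, arbitrary values), not a full implementation: the vanilla target takes the maximum of the target network's estimates, Double DQN lets the online network choose $a'$ and the target network evaluate it, and the dueling head combines $V(s)$ with mean-centred advantages $A(s,a)$.

```python
import numpy as np

gamma = 0.99
reward = np.array([1.0, 1.0])
q_online_next = np.array([[1.0, 3.0, 2.0],
                          [0.0, 1.0, 4.0]])  # Q(s', ·; θ), online network
q_target_next = np.array([[2.0, 1.0, 5.0],
                          [1.0, 3.0, 0.5]])  # Q(s', ·; θ⁻), target network

# Vanilla DQN: the target network both selects and evaluates a'
vanilla_target = reward + gamma * q_target_next.max(axis=1)  # [5.95, 3.97]

# Double DQN: the online network selects a', the target network evaluates it
best_actions = q_online_next.argmax(axis=1)                  # [1, 2]
double_target = reward + gamma * q_target_next[np.arange(len(reward)), best_actions]
# = [1 + 0.99*1.0, 1 + 0.99*0.5] = [1.99, 1.495]

# Dueling head: Q(s, a) = V(s) + A(s, a) - mean over a' of A(s, a')
value = np.array([[10.0]])                 # V(s), shape (batch, 1)
advantage = np.array([[1.0, -1.0, 3.0]])   # A(s, ·), shape (batch, n_actions)
q_dueling = value + advantage - advantage.mean(axis=1, keepdims=True)
# mean advantage is 1.0, so q_dueling = [[10., 8., 12.]]
```

For the same pair of networks, the Double DQN target can never exceed the vanilla target, because evaluating the target network at any single action gives at most its maximum.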
DQN Implementation
In this example, we use AirRaid, one of the games in the Atari 2600 suite, to demonstrate how DQN can learn to play directly from raw screen pixels. We also use a Convolutional Neural Network (CNN), the standard architecture for image data, as the Q-function approximator. This case study highlights how DQN handles complex environments through visual input.
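Before reading the network code, it helps to check the shapes by hand. For a convolution with no padding, the output side length is $\lfloor (n - k)/s \rfloor + 1$ for input size $n$, kernel size $k$, and stride $s$. The short sketch below applies this to the three convolutional layers used in the implementation (kernel/stride pairs 8/4, 4/2, and 3/1 on an 84×84 input); the helper name `conv_out` is our own, not a PyTorch function.

```python
def conv_out(n, kernel, stride):
    """Side length after a 2D convolution with no padding and dilation 1."""
    return (n - kernel) // stride + 1

size = 84
for kernel, stride in [(8, 4), (4, 2), (3, 1)]:
    size = conv_out(size, kernel, stride)
    print(size)  # 20, then 9, then 7

flat = 64 * size * size  # the last conv layer has 64 output channels
print(flat)  # 3136
```

So for a (4, 84, 84) input, the flattened convolutional output, and therefore the input size of the first fully connected layer, should be 3136.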
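```python
import os
import time
import ale_py  # Atari (ALE) environments for Gymnasium; needed before gym.make("ALE/...")
import random
import argparse
import numpy as np
import gymnasium as gym
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from datetime import datetime
from collections import deque
from gymnasium.wrappers import AtariPreprocessing
from gymnasium.wrappers import FrameStackObservation as FrameStack

# Importing ale_py registers the "ALE/..." IDs; register_envs makes the dependency explicit (gymnasium>=1.0)
gym.register_envs(ale_py)


class DQN(nn.Module):
    """CNN mapping a stack of 4 grayscale 84x84 frames to one Q-value per action."""

    def __init__(self, input_shape, n_actions):
        super(DQN, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU()
        )
        conv_out_size = self._get_conv_out(input_shape)
        self.fc = nn.Sequential(
            nn.Linear(conv_out_size, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions)
        )

    def _get_conv_out(self, shape):
        """
        Get the flattened output size of the convolutional layers
        Args:
            shape: The shape of the input tensor, e.g. (4, 84, 84)
        Returns:
            The flattened output size (64 * 7 * 7 = 3136 for an 84x84 input)
        """
        o = self.conv(torch.zeros(1, *shape))
        return int(np.prod(o.size()))

    def forward(self, x):
        # Flatten the feature maps of each sample, then score every action
        conv_out = self.conv(x).view(x.size()[0], -1)
        return self.fc(conv_out)


class DQNAgent:
    def __init__(self, state_shape, n_actions, device):
        self.device = device
        self.n_actions = n_actions
        self.epsilon = 1.0            # initial exploration rate
        self.epsilon_min = 0.1
        self.epsilon_decay = 0.995    # multiplicative decay applied after every training step
        self.gamma = 0.99             # discount factor
        self.batch_size = 32
        self.memory = deque(maxlen=10000)  # experience replay buffer
        self.policy_net = DQN(state_shape, n_actions).to(device)  # online network, θ
        self.target_net = DQN(state_shape, n_actions).to(device)  # target network, θ⁻
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.optimizer = optim.Adam(self.policy_net.parameters())
        self.criterion = nn.MSELoss()

    def choose_action(self, state):
        # Epsilon-greedy: explore with probability epsilon, otherwise act greedily
        if random.random() < self.epsilon:
            return random.randrange(self.n_actions)
        with torch.no_grad():
            state = torch.as_tensor(np.asarray(state), dtype=torch.float32, device=self.device).unsqueeze(0)
            q_values = self.policy_net(state)
        return q_values.max(1)[1].item()

    def memorize(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def train(self):
        if len(self.memory) < self.batch_size:
            return None
        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        # Stack into single numpy arrays first; building a tensor from a tuple of arrays is slow
        states = torch.as_tensor(np.array(states), dtype=torch.float32, device=self.device)
        actions = torch.as_tensor(np.array(actions), dtype=torch.int64, device=self.device)
        rewards = torch.as_tensor(np.array(rewards), dtype=torch.float32, device=self.device)
        next_states = torch.as_tensor(np.array(next_states), dtype=torch.float32, device=self.device)
        dones = torch.as_tensor(np.array(dones), dtype=torch.float32, device=self.device)

        # Q(s, a; θ) for the actions that were actually taken
        current_q_values = self.policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
        # max_a' Q(s', a'; θ⁻) from the target network, without tracking gradients
        with torch.no_grad():
            next_q_values = self.target_net(next_states).max(1)[0]
        # Bellman target; (1 - dones) removes the bootstrap term at terminal states
        expected_q_values = rewards + (1 - dones) * self.gamma * next_q_values
        loss = self.criterion(current_q_values, expected_q_values)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay
        return loss.item()

    def update_target_network(self):
        # Copy the online weights θ into the target network θ⁻
        self.target_net.load_state_dict(self.policy_net.state_dict())


def make_env(render_mode=None):
    env = gym.make('ALE/AirRaid-v5', render_mode=render_mode)
    # The v5 environment already repeats each action for 4 frames, so the wrapper must not skip again
    env = AtariPreprocessing(env, frame_skip=1, grayscale_obs=True, scale_obs=True)
    # Stack the 4 most recent frames: observations have shape (4, 84, 84)
    return FrameStack(env, stack_size=4)


def train(args):
    env = make_env()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    state_shape = (4, 84, 84)
    n_actions = env.action_space.n
    agent = DQNAgent(state_shape, n_actions, device)
    rewards = []
    losses = []
    episode_times = []
    episodes = args.episodes
    print(f"Starting DQN training for {episodes} episodes...")
    print(f"State shape: {state_shape}, Actions: {n_actions}")
    print("-" * 80)
    # Create dqn directory if it doesn't exist
    os.makedirs("dqn", exist_ok=True)
    # Generate timestamp for this training session
    session_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    print(f"Training session ID: {session_timestamp}")
    print("-" * 80)
    training_start_time = time.time()
    for episode in range(episodes):
        episode_start_time = time.time()
        state, _ = env.reset()
        total_reward = 0
        done = False
        episode_losses = []
        while not done:
            action = agent.choose_action(state)
            next_state, reward, terminated, truncated, info = env.step(action)
            # The episode ends on game over (terminated) or on a time limit (truncated)
            done = terminated or truncated
            # Only a true terminal state should cut off bootstrapping in the target
            agent.memorize(state, action, reward, next_state, terminated)
            loss = agent.train()
            if loss is not None:
                episode_losses.append(loss)
            state = next_state
            total_reward += reward
        # Sync the target network once per episode
        agent.update_target_network()
        rewards.append(total_reward)
        avg_loss = np.mean(episode_losses) if episode_losses else 0
        losses.append(avg_loss)
        episode_times.append(time.time() - episode_start_time)
        print(f"Episode {episode + 1}/{episodes} | Reward: {total_reward:.1f} | "
              f"Avg loss: {avg_loss:.4f} | Epsilon: {agent.epsilon:.3f} | Time: {episode_times[-1]:.1f}s")
    print(f"Total training time: {time.time() - training_start_time:.1f}s")
    # Save the trained weights (hypothetical file name) so they can be replayed with --simulate
    model_path = os.path.join("dqn", f"dqn_airraid_{session_timestamp}.pth")
    torch.save(agent.policy_net.state_dict(), model_path)
    print(f"Model saved to {model_path}")
    env.close()
    return rewards, losses


# Minimal stand-ins for the helpers called below; the full versions live in the example repo,
# so the argument names and behavior here are assumptions.
def parse_args():
    parser = argparse.ArgumentParser(description="DQN on Atari AirRaid")
    parser.add_argument("--episodes", type=int, default=100, help="number of training episodes")
    parser.add_argument("--simulate", action="store_true", help="watch a trained agent instead of training")
    parser.add_argument("--model", type=str, help="path to saved weights, used with --simulate")
    return parser.parse_args()


def simulate(args):
    # Play one greedy episode on screen using previously saved weights
    env = make_env(render_mode="human")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    agent = DQNAgent((4, 84, 84), env.action_space.n, device)
    agent.policy_net.load_state_dict(torch.load(args.model, map_location=device))
    agent.epsilon = 0.0  # no exploration
    state, _ = env.reset()
    done, total_reward = False, 0.0
    while not done:
        state, reward, terminated, truncated, _ = env.step(agent.choose_action(state))
        done = terminated or truncated
        total_reward += reward
    print(f"Simulation reward: {total_reward}")
    env.close()


def plot_curves(rewards, losses):
    # Per-episode total reward and mean training loss, side by side
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
    ax1.plot(rewards)
    ax1.set_xlabel("Episode")
    ax1.set_ylabel("Total reward")
    ax2.plot(losses)
    ax2.set_xlabel("Episode")
    ax2.set_ylabel("Mean loss")
    fig.tight_layout()
    plt.show()


if __name__ == "__main__":
    args = parse_args()
    if args.simulate:
        simulate(args)
    else:
        rewards, losses = train(args)
        plot_curves(rewards, losses)
```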
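Two quick back-of-the-envelope checks on the hyperparameters above may help when tuning. The sketch below reuses the values from `DQNAgent` (`epsilon_decay=0.995` applied once per gradient step, `epsilon_min=0.1`, a 10,000-transition buffer) and assumes each stored observation is a float32 array of shape (4, 84, 84), which is what `scale_obs=True` produces. Both numbers follow directly from these assumptions; they are not measurements.

```python
import math

# How many gradient steps until epsilon falls from 1.0 to epsilon_min?
epsilon, epsilon_min, decay = 1.0, 0.1, 0.995
steps = 0
while epsilon > epsilon_min:
    epsilon *= decay
    steps += 1
print(steps)  # 460
print(math.ceil(math.log(epsilon_min) / math.log(decay)))  # 460, closed form

# Replay buffer footprint: every transition stores both state and next_state
capacity, frames, height, width, bytes_per_float32 = 10_000, 4, 84, 84, 4
buffer_bytes = capacity * 2 * frames * height * width * bytes_per_float32
print(buffer_bytes / 1e9)  # 2.25792, i.e. about 2.26 GB
```

With these settings, exploration decays to its floor after only 460 updates; since training starts once the buffer holds 32 transitions, that is roughly the first 500 environment steps. The full buffer also needs over 2 GB of RAM; storing frames as uint8 (for example with `scale_obs=False`) and dividing by 255 inside the network would cut that by a factor of four.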
We apply a CNN-based DQN to the Atari 2600 game AirRaid to demonstrate the effectiveness of the algorithm. Below, we show the performance of DQN after training for only 5 epochs.
[Video: DQN with AirRaid - CSY]
On the game screen, we can see the agent learning to avoid enemy attacks while simultaneously attacking the enemy. The example code can be found in this repo.
Conclusion
In this post, we demonstrated the capabilities of DQN, particularly in environments with image-based inputs such as Atari games. DQN is a foundational algorithm that combines Q-learning with deep convolutional neural networks, allowing it to approximate action-value functions directly from high-dimensional inputs such as raw pixel data.
DQN is limited to discrete action spaces, because the network outputs one value per action and the target requires a maximum over all of them. Within that setting, it remains a powerful and effective algorithm when carefully tuned for specific tasks. Compared with tabular Q-learning, which operates on small, discrete, or low-dimensional numerical states, DQN typically requires significantly more computation and training time, especially when processing complex visual inputs.
Looking forward, DQN continues to be a relevant and insightful approach, particularly in research and applications involving discrete control from rich sensory inputs. We will continue exploring the possibilities of DQN by applying it to a variety of interesting and challenging tasks in the near future.