第14回 深度 Q 网络——DQN 与经验回放

旧账不翻难算清,新招要靠旧招成。
回放一匣藏百局,抽丝剥茧见输赢。

上回我们立了 MDP 的账本与贝尔曼递推:
今天这一口 + 明天账本的折扣期望。
可账本写出来是一回事,真要把它算出来又是另一回事——江湖大、状态多,表格根本记不下。

于是这门武功走到这里,就要亮出一句话:
用神经网络去近似 Q。

这便是深度 Q 网络(DQN)的来历:把“表格 Q-learning”升级成“神经网络 Q-learning”。
它之所以成了深度强化学习的开山派之一,是因为它把两个最要命的不稳定源头,用两把钉子钉住:

  • 经验回放(Replay Buffer):打散相关性,让学习更像“从题库抽题”
  • 目标网络(Target Network):让目标别乱跑,训练不至于自我追逐

一、先把 Q-learning 的骨法说透

你在第13回见过最优性递推的影子。
Q-learning 的骨法也一句:

我先估每个动作的“前途分”(Q),然后永远朝着前途分最大的那个走。

在表格时代,Q 是一张表:每个 s,as,a 一格。
但现实里 ss 可能是图像、是文本、是长向量,表格会爆炸。

所以 DQN 把 Q 表换成一个函数:

Q(s,a)Qθ(s,a)Q(s,a) \approx Q_\theta(s,a)

这里 θ\theta 是神经网络参数。
你可以把它理解为:给我状态 ss,我一次性吐出所有动作的“前途分”。


二、DQN 的训练目标:把贝尔曼递推变成“拟合题”

贝尔曼最优性告诉我们:

真正的 QQ^* 满足:
当前这一口 + 折扣后的“下一步最优前途”。

DQN 把它写成一个监督学习式的“拟合题”:

  • 你从经验里拿一条记录 s,a,r,s,dones,a,r,s',done
  • 你算一个目标值:
y=r+γmaxaQθ(s,a)y = r + \gamma \max_{a'} Q_{\theta^-}(s',a')

这里 θ\theta^- 是目标网络参数(后面解释)。
然后让当前网络 Qθ(s,a)Q_\theta(s,a) 去逼近 yy,最常见就是最小化平方误差:

(Qθ(s,a)y)2\left(Q_\theta(s,a)-y\right)^2

你看懂这一条,就懂了 DQN 的全部:
它把“长远账”压缩成“下一步最优前途”,再用梯度下降去拟合。


三、为什么不稳定:自己当裁判,自己当运动员

若你直接用同一个网络既算 Qθ(s,a)Q_\theta(s,a) 又算目标里的 maxQθ(s,a)\max Q_\theta(s',a'),就会发生一种很糟的现象:

  • 你追着一个会动的目标跑
  • 目标的移动还由你自己决定

这就像你一边写答案,一边改标准答案,最后越改越乱。

因此 DQN 的第一把钉子是:目标网络


四、目标网络:标准答案慢一点改

目标网络做法极朴素:

  • 维护两份网络:在线网络 QθQ_\theta 与目标网络 QθQ_{\theta^-}
  • 训练时用目标网络算目标 yy
  • 每隔一段时间,把在线网络参数拷贝给目标网络(让标准答案“跳一次”,而不是每步都动)

你可以把它理解为:
让老师的参考答案隔一会儿才更新一次,学生才有可能学得稳定。


五、经验回放:别只背刚写过的那一页

DQN 的第二把钉子是:经验回放

如果你按时间顺序一条条训练:

  • 连续样本高度相关
  • 你刚学到的东西会立刻被下一段相关样本“带偏”
  • 学习就像在一条狭窄河道里打转

经验回放的做法是:

  • 把经历过的转移存进一个“回放匣子”
  • 每次训练随机抽一批出来

于是学习更像:

从题库里抽题做,而不是只做刚刚遇到的一页。

这能显著减少相关性带来的抖动,也提高样本利用率。
经验回放怎么抽、怎么保持多样性,到了 2024 也仍然有人在做改进与研究。1


六、极简代码:一个小格子世界里的 DQN(PyTorch 可跑)

下面写一个最小可跑的 DQN:

环境是 1 条线上的 5 个格子,起点在最左,目标在最右。动作只有两个:向左或向右。到终点给 +1,其余每步给一点点小惩罚,逼它别磨蹭。

import random
from collections import deque

import torch
import torch.nn as nn
import torch.nn.functional as F


class LineWorld:
    def __init__(self, n=5):
        self.n = n
        self.reset()

    def reset(self):
        self.pos = 0
        return self.pos

    def step(self, action):
        # 0: left, 1: right
        if action == 0:
            self.pos = max(0, self.pos - 1)
        else:
            self.pos = min(self.n - 1, self.pos + 1)

        done = (self.pos == self.n - 1)
        reward = 1.0 if done else -0.1
        return self.pos, reward, done


class QNet(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, 64)
        self.fc2 = nn.Linear(64, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)


# Hyperparameters
gamma = 0.9
lr = 0.001
epsilon = 0.1
buffer_size = 1000
batch_size = 32

env = LineWorld()
q_net = QNet(1, 2)
target_net = QNet(1, 2)
target_net.load_state_dict(q_net.state_dict())  # 初始同步
optimizer = torch.optim.Adam(q_net.parameters(), lr=lr)
replay_buffer = deque(maxlen=buffer_size)

for episode in range(500):
    state = env.reset()
    total_reward = 0
    while True:
        # Epsilon-greedy
        if random.random() < epsilon:
            action = random.randint(0, 1)
        else:
            state_tensor = torch.FloatTensor([state]).unsqueeze(0)
            action = q_net(state_tensor).argmax().item()

        next_state, reward, done = env.step(action)
        replay_buffer.append((state, action, reward, next_state, done))

        state = next_state
        total_reward += reward

        # Training step
        if len(replay_buffer) > batch_size:
            batch = random.sample(replay_buffer, batch_size)
            s_b, a_b, r_b, ns_b, d_b = zip(*batch)

            s_b = torch.FloatTensor(s_b).unsqueeze(1)
            a_b = torch.LongTensor(a_b).unsqueeze(1)
            r_b = torch.FloatTensor(r_b).unsqueeze(1)
            ns_b = torch.FloatTensor(ns_b).unsqueeze(1)
            d_b = torch.FloatTensor(d_b).unsqueeze(1)

            # Q(s,a)
            q_val = q_net(s_b).gather(1, a_b)

            # Target: r + gamma * max Q_target(s', a')
            with torch.no_grad():
                next_q = target_net(ns_b).max(1)[0].unsqueeze(1)
                target = r_b + gamma * next_q * (1 - d_b)

            loss = F.mse_loss(q_val, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if done:
            break

    # Sync target network every 20 episodes
    if episode % 20 == 0:
        target_net.load_state_dict(q_net.state_dict())
        print(f"Episode {episode}, Reward: {total_reward:.2f}")

# Final test
state = env.reset()
print("Testing trained agent:")
for _ in range(10):
    state_tensor = torch.FloatTensor([state]).unsqueeze(0)
    action = q_net(state_tensor).argmax().item()
    print(f"Pos: {state} -> Action: {'Left' if action == 0 else 'Right'}")
    state, _, done = env.step(action)
    if done:
        print("Reached Goal!")
        break

这段代码虽短,却五脏俱全。
你可以试着把 target_net 去掉,直接用 q_net 算目标,看看收敛会不会变慢甚至震荡。


下回预告

DQN 虽然强,但它只适合“离散动作”(比如按键、下棋)。
如果我要控制机器人手臂,动作是连续的角度(比如 30.5度),Q 表也好,Q 网络也好,怎么输出?

这就需要另一派绝学:策略梯度(Policy Gradient)
它是直接学“怎么做”,而不是学“哪样好”。

下回:策略梯度——REINFORCE 与基线。

Footnotes

  1. 参考论文:Prioritized Experience Replay (ICLR 2016) 以及近期的 Replay Ratio 研究。