第06回 注意力机制——让模型学会聚焦

满堂灯火谁为主,一束微光照要津。
万言纷纷皆可弃,只取关键定乾坤。

上回我们打完情感分析的擂台,明白了一个现实:
句子长起来,信息就像集市——吆喝声太多,你得知道该听谁。

RNN/LSTM 的“记事”,像一边走路一边做笔记:能记,但笔记越厚越难翻,越长越费劲。
于是江湖里出现了一门新功夫:注意力(Attention)

它不求把整条街都背下来,而是每次提问时,都能把灯光打在“此刻最相关的那几句”上。


一、注意力的直觉:三件物什——问、钥、账

注意力最常见的说法是 Q、K、V:

  • Query(Q):你此刻想问什么
  • Key(K):每条信息的“索引牌”
  • Value(V):每条信息的“内容账本”

你可以把它想成“查卷宗”:

  1. 你带着一个问题(Q)进档案库
  2. 你看每份卷宗门牌(K),判断像不像你要找的
  3. 你按相似度给每份卷宗一个权重
  4. 最后把卷宗内容(V)按权重加权求和,得到一份“聚焦后的摘要”

这份摘要,就是注意力的输出。


二、最核心的一步:用点积算“像不像”

上回第01回我们见过点积与余弦相似度。注意力这里也用相似度,只是更“工程化”:

对某个位置的 Query 向量 qq,对所有 Key 向量 k1,,knk_1,\dots,k_n,计算相似度分数:

si=qkis_i = q \cdot k_i

分数越大,表示“越相关”。

接着用 softmax 把分数变成权重(让权重都在 0 到 1 之间,并且加起来等于 1):

αi=softmax(s)i\alpha_i = \text{softmax}(s)_i

最后输出是 Value 的加权平均:

o=i=1nαivio = \sum_{i=1}^{n} \alpha_i v_i

这就是注意力的核心骨法:
点积做打分,softmax 做归一,加权求和做聚焦。

为什么要“缩放”(scaled)?
因为维度大时点积数值会变大,softmax 容易极端化(像一不小心就把灯全打到一个点上)。常见做法是除以 d\sqrt{d},其中 dd 是向量维度:

si=qkids_i = \frac{q \cdot k_i}{\sqrt{d}}

你无需证明它,只记直觉:把分数调回合适的量级,训练更稳。


三、自注意力与交叉注意力:是“自省”还是“问外人”

注意力有两种常见用法:

  • 自注意力(self-attention):Q、K、V 都来自同一段序列(自己看自己)
  • 交叉注意力(cross-attention):Q 来自一段序列,K/V 来自另一段序列(问外部记忆)

你可把它与后面要学的系统一对照:

  • Transformer 编码器里大量用自注意力
  • 翻译/对话的编码器‑解码器结构里,解码器会用交叉注意力去“看编码器输出”
  • RAG 与智能体里,“问外部文档/工具结果”,在抽象层面也像交叉注意力:问题是 Q,证据是 K/V

四、因果遮罩:让模型别偷看答案

当我们要做“从左到右”的生成(像写作文),就不能让当前位置偷看未来的词。
于是注意力会加一个“遮罩”(mask):

  • 未来位置的分数直接设为极小值
  • softmax 后权重几乎为 0

这就是“因果注意力”(causal attention)的关键规矩:
只能看过去,不能看未来。

这一点会在第08回(自回归预训练)里成为主角。


五、多头注意力:一盏灯不够,就点八盏

有时一句话里,“相关性”不止一种:

  • 有的关系是“指代”(他/她/它指谁)
  • 有的关系是“否定”(不、没、别)
  • 有的关系是“搭配”(很 + 好吃)

多头注意力(multi-head attention)就像:

把同一个输入,投影到多个不同“视角”,每个视角都做一遍注意力,再把结果拼起来。

你可以把它理解为:
让模型同时从多种角度“看重点”。


六、代价与现实:为什么注意力会成为瓶颈

注意力厉害,但也贵。

当序列长度是 nn 时,注意力要算一个 n×nn\times n 的打分表(每个位置看每个位置),计算与内存都容易变成瓶颈。

这正是后面“后 Transformer 时代”(第五篇第41回)要对付的老敌人:
长上下文带来的二次方成本。

因此 2024 年仍有大量工作围绕“如何把注意力算得更快、更稳、更省”在下功夫。比如 FlashAttention-3 就针对新硬件特性,把注意力实现做得更高效,并讨论低精度下的数值误差控制。1

注意:本书在第06回只讲“武功招式的原理”,不讲 GPU 内功细节;但你要知道江湖里为什么会有这些优化:
因为注意力太常用、太关键、也太费钱。


七、极简代码:手写一次缩放点积注意力(PyTorch 可跑)

下面用最小代码手写一次单头注意力,展示每一步长什么样。

import math

import torch
import torch.nn.functional as F


def scaled_dot_attention(q, k, v, mask=None):
    # q: [B, Tq, D], k/v: [B, Tk, D]
    d = q.size(-1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d)  # [B, Tq, Tk]
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    w = F.softmax(scores, dim=-1)  # [B, Tq, Tk]
    out = torch.matmul(w, v)       # [B, Tq, D]
    return out, w


if __name__ == "__main__":
    torch.manual_seed(0)
    B, T, D = 1, 4, 8
    x = torch.randn(B, T, D)

    # self-attention: q=k=v=x
    out, w = scaled_dot_attention(x, x, x)
    print("weights shape:", w.shape)
    print("weights row0:", [round(v, 3) for v in w[0, 0].tolist()])

    # causal mask: only attend to <= current position
    causal = torch.tril(torch.ones(T, T)).unsqueeze(0).unsqueeze(0)  # [1,1,T,T]
    out2, w2 = scaled_dot_attention(x, x, x, mask=causal[0])
    print("causal row3:", [round(v, 3) for v in w2[0, 3].tolist()])

运行后你会看到:

  • 没遮罩时,每个位置可以把灯照向任意位置
  • 有因果遮罩后,第 4 个位置只能照向前 4 个(不能照未来)

这就是“聚焦”的可计算版本。


八、小结:本回学会“聚焦”,下回要见“整座城”

本回你已经掌握注意力的三件法宝:

  • Q/K/V:问、钥、账
  • 点积 + softmax:打分并归一
  • mask:守规矩,不偷看

下一回要登场的,是把这些招式组装成一座完整城池的架构:Transformer
届时你会看到:

  • 注意力如何与残差、归一化、前馈网络拼成一层
  • 多层堆叠如何让模型越练越深
  • 为什么它会成为大模型的骨架

欲知后事如何,且听下回分解。


引用与溯源

Footnotes

  1. Shah, J., et al. FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision arXiv:2407.08608(2024-07)https://arxiv.org/abs/2407.08608