第06回 注意力机制——让模型学会聚焦
满堂灯火谁为主,一束微光照要津。
万言纷纷皆可弃,只取关键定乾坤。
上回我们打完情感分析的擂台,明白了一个现实:
句子长起来,信息就像集市——吆喝声太多,你得知道该听谁。
RNN/LSTM 的“记事”,像一边走路一边做笔记:能记,但笔记越厚越难翻,越长越费劲。
于是江湖里出现了一门新功夫:注意力(Attention)。
它不求把整条街都背下来,而是每次提问时,都能把灯光打在“此刻最相关的那几句”上。
一、注意力的直觉:三件物什——问、钥、账
注意力最常见的说法是 Q、K、V:
- Query(Q):你此刻想问什么
- Key(K):每条信息的“索引牌”
- Value(V):每条信息的“内容账本”
你可以把它想成“查卷宗”:
- 你带着一个问题(Q)进档案库
- 你看每份卷宗门牌(K),判断像不像你要找的
- 你按相似度给每份卷宗一个权重
- 最后把卷宗内容(V)按权重加权求和,得到一份“聚焦后的摘要”
这份摘要,就是注意力的输出。
二、最核心的一步:用点积算“像不像”
上回第01回我们见过点积与余弦相似度。注意力这里也用相似度,只是更“工程化”:
对某个位置的 Query 向量 ,对所有 Key 向量 ,计算相似度分数:
分数越大,表示“越相关”。
接着用 softmax 把分数变成权重(让权重都在 0 到 1 之间,并且加起来等于 1):
最后输出是 Value 的加权平均:
这就是注意力的核心骨法:
点积做打分,softmax 做归一,加权求和做聚焦。
为什么要“缩放”(scaled)?
因为维度大时点积数值会变大,softmax 容易极端化(像一不小心就把灯全打到一个点上)。常见做法是除以 ,其中 是向量维度:
你无需证明它,只记直觉:把分数调回合适的量级,训练更稳。
三、自注意力与交叉注意力:是“自省”还是“问外人”
注意力有两种常见用法:
- 自注意力(self-attention):Q、K、V 都来自同一段序列(自己看自己)
- 交叉注意力(cross-attention):Q 来自一段序列,K/V 来自另一段序列(问外部记忆)
你可把它与后面要学的系统一对照:
- Transformer 编码器里大量用自注意力
- 翻译/对话的编码器‑解码器结构里,解码器会用交叉注意力去“看编码器输出”
- RAG 与智能体里,“问外部文档/工具结果”,在抽象层面也像交叉注意力:问题是 Q,证据是 K/V
四、因果遮罩:让模型别偷看答案
当我们要做“从左到右”的生成(像写作文),就不能让当前位置偷看未来的词。
于是注意力会加一个“遮罩”(mask):
- 未来位置的分数直接设为极小值
- softmax 后权重几乎为 0
这就是“因果注意力”(causal attention)的关键规矩:
只能看过去,不能看未来。
这一点会在第08回(自回归预训练)里成为主角。
五、多头注意力:一盏灯不够,就点八盏
有时一句话里,“相关性”不止一种:
- 有的关系是“指代”(他/她/它指谁)
- 有的关系是“否定”(不、没、别)
- 有的关系是“搭配”(很 + 好吃)
多头注意力(multi-head attention)就像:
把同一个输入,投影到多个不同“视角”,每个视角都做一遍注意力,再把结果拼起来。
你可以把它理解为:
让模型同时从多种角度“看重点”。
六、代价与现实:为什么注意力会成为瓶颈
注意力厉害,但也贵。
当序列长度是 时,注意力要算一个 的打分表(每个位置看每个位置),计算与内存都容易变成瓶颈。
这正是后面“后 Transformer 时代”(第五篇第41回)要对付的老敌人:
长上下文带来的二次方成本。
因此 2024 年仍有大量工作围绕“如何把注意力算得更快、更稳、更省”在下功夫。比如 FlashAttention-3 就针对新硬件特性,把注意力实现做得更高效,并讨论低精度下的数值误差控制。1
注意:本书在第06回只讲“武功招式的原理”,不讲 GPU 内功细节;但你要知道江湖里为什么会有这些优化:
因为注意力太常用、太关键、也太费钱。
七、极简代码:手写一次缩放点积注意力(PyTorch 可跑)
下面用最小代码手写一次单头注意力,展示每一步长什么样。
import math
import torch
import torch.nn.functional as F
def scaled_dot_attention(q, k, v, mask=None):
# q: [B, Tq, D], k/v: [B, Tk, D]
d = q.size(-1)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d) # [B, Tq, Tk]
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
w = F.softmax(scores, dim=-1) # [B, Tq, Tk]
out = torch.matmul(w, v) # [B, Tq, D]
return out, w
if __name__ == "__main__":
torch.manual_seed(0)
B, T, D = 1, 4, 8
x = torch.randn(B, T, D)
# self-attention: q=k=v=x
out, w = scaled_dot_attention(x, x, x)
print("weights shape:", w.shape)
print("weights row0:", [round(v, 3) for v in w[0, 0].tolist()])
# causal mask: only attend to <= current position
causal = torch.tril(torch.ones(T, T)).unsqueeze(0).unsqueeze(0) # [1,1,T,T]
out2, w2 = scaled_dot_attention(x, x, x, mask=causal[0])
print("causal row3:", [round(v, 3) for v in w2[0, 3].tolist()])
运行后你会看到:
- 没遮罩时,每个位置可以把灯照向任意位置
- 有因果遮罩后,第 4 个位置只能照向前 4 个(不能照未来)
这就是“聚焦”的可计算版本。
八、小结:本回学会“聚焦”,下回要见“整座城”
本回你已经掌握注意力的三件法宝:
- Q/K/V:问、钥、账
- 点积 + softmax:打分并归一
- mask:守规矩,不偷看
下一回要登场的,是把这些招式组装成一座完整城池的架构:Transformer。
届时你会看到:
- 注意力如何与残差、归一化、前馈网络拼成一层
- 多层堆叠如何让模型越练越深
- 为什么它会成为大模型的骨架
欲知后事如何,且听下回分解。
引用与溯源
Footnotes
-
Shah, J., et al. FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision arXiv:2407.08608(2024-07)https://arxiv.org/abs/2407.08608 ↩