第17回 先学人话再学脾气——SFT 与 RLHF 全流水线
先把口条学周正,方能上阵论输赢。
尺子一立分高下,脾气再调近人情。
上回我们讲 PPO 的魂:步子别迈大,要有刹车。
可看官要问:PPO 是“怎么改策略”,那策略凭什么要改、往哪改?
若没有方向,刹车踩得再稳,也只是原地打转。
所以这一回,我们把 RLHF 的全流水线摆上桌:
从“先学人话”(SFT)到“再学脾气”(RLHF)。
你只要把这条流水线看明白,后面 DPO/IPO/KTO(第18回)才看得出它们究竟是在“替代哪一段”。
一、先说人话:SFT 是“学徒跟师傅抄作业”
SFT(Supervised Fine-Tuning)是什么?
一句话:用高质量示范数据,把模型先拉到“像样”的起跑线。
你可以把它想成:
- 师傅给出标准答法(示范回答)
- 学徒照着抄,抄到像样为止
数学上就是普通的监督学习:让模型在给定提示词 时,尽量生成示范答案 。
其好处非常朴素:
- 让模型先会基本礼貌、基本格式
- 让模型先别满嘴跑火车
- 让后续的“脾气微调”有一个稳的底盘
在 RLHF 的经典实践里,SFT 是明确的一步。1
二、再立尺子:奖励模型 RM 是“把偏好变成分数”
人类偏好往往不是“对/错”这么硬:
有时两条回答都对,但一个更清楚、更安全、更不绕。
要让机器学这种偏好,常用的方法是:
- 给同一个提示词 ,拿到两条候选回答 与
- 人类(或偏好数据集)选出更喜欢的那条
然后训练一个奖励模型 ,让它满足:
最常见的训练方式,是用一个“胜负概率”去拟合人类选择:
把“更喜欢谁”建模成逻辑回归/Softmax 的形式:
这一步非常像“打分老师”上岗:
他不直接教你怎么写,而是给你每次写的东西打分。
你会在很多 RLHF 论文里看到这种“成对偏好 + 奖励模型”的结构。2 1
三、最后调脾气:用 PPO 把“想要的风格”写进策略
现在方向有了:奖励模型就是方向盘。
接下来要做的是:让策略模型生成的回答,在奖励模型眼里更高分。
但直接把策略往奖励高分的方向推,会有两个工程灾难:
- 奖励黑客:模型学会钻奖励模型的漏洞
- 语言崩坏:模型为了分数,把原本的“人话能力”弄丢
所以 RLHF 的 PPO 往往还会加一条“别跑太远”的约束:
让新策略别偏离参考策略 太多(参考策略通常就是 SFT 后的模型)。
直觉就是:
你可以变乖,但别变成另一个怪物。
这与第16回“步子别迈大”的精神完全一致:
PPO 的 clip 管的是“单步更新幅度”,KL 约束/惩罚管的是“总体风格别漂移”。
四、你真正需要记住的 RLHF:三件事、两份模型、一条绳
把流水线压缩成最小记忆包:
- 三件事:SFT(学人话)、RM(立尺子)、RL(调脾气)
- 两份模型:策略模型(要被调)、奖励模型(负责打分)
- 一条绳:别离参考策略太远(KL 约束/惩罚的精神)
你要是把这包揣兜里,后面看到各种变体都不慌:
它们无非是改了“尺子怎么立”“脾气怎么调”“绳子怎么拴”。
五、极简可跑代码:玩具版 RLHF(SFT→RM→PPO 更新一小步)
下面这段代码不做大模型,只做一个最小“二选一回答”的玩具版 RLHF:
- 每个提示词只有两个回答可选:回答0 或 回答1
- “人类偏好”由隐藏规则决定(我们用概率模拟)
- 先做 SFT:把策略往偏好答案靠
- 再训 RM:用成对偏好训练一个打分器
- 最后做一轮 PPO 式更新:用 RM 的分数当奖励,带 KL 惩罚靠住参考策略
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
def make_prompts(n=64):
return torch.arange(n, dtype=torch.long)
def hidden_preference(prompt_id):
return 1 if (prompt_id % 2 == 0) else 0
def sample_pair(prompt_id, rnd):
good = hidden_preference(prompt_id)
bad = 1 - good
if rnd.random() < 0.5:
return good, bad
return bad, good
class Policy(nn.Module):
def __init__(self, n_prompt, n_action=2, d=16):
super().__init__()
self.emb = nn.Embedding(n_prompt, d)
self.head = nn.Linear(d, n_action)
def logits(self, prompts):
return self.head(self.emb(prompts))
def dist(self, prompts):
return torch.distributions.Categorical(logits=self.logits(prompts))
class RewardModel(nn.Module):
def __init__(self, n_prompt, n_action=2, d=16):
super().__init__()
self.emb = nn.Embedding(n_prompt, d)
self.head = nn.Linear(d + n_action, 1)
self.n_action = n_action
def score(self, prompts, actions):
a = F.one_hot(actions, num_classes=self.n_action).float()
x = torch.cat([self.emb(prompts), a], dim=-1)
return self.head(x).squeeze(-1)
def kl_categorical(logits_p, logits_q):
p = F.log_softmax(logits_p, dim=-1)
q = F.log_softmax(logits_q, dim=-1)
pp = p.exp()
return (pp * (p - q)).sum(dim=-1)
if __name__ == "__main__":
torch.manual_seed(0)
rnd = random.Random(0)
n_prompt = 64
prompts = make_prompts(n_prompt)
pi = Policy(n_prompt)
pi_ref = Policy(n_prompt)
pi_ref.load_state_dict(pi.state_dict())
rm = RewardModel(n_prompt)
opt_sft = torch.optim.Adam(pi.parameters(), lr=1e-2)
opt_rm = torch.optim.Adam(rm.parameters(), lr=1e-2)
opt_ppo = torch.optim.Adam(pi.parameters(), lr=1e-3)
for _ in range(200):
idx = torch.randint(0, n_prompt, (32,))
good = torch.tensor([hidden_preference(int(i)) for i in idx], dtype=torch.long)
loss = F.cross_entropy(pi.logits(idx), good)
opt_sft.zero_grad()
loss.backward()
opt_sft.step()
for _ in range(400):
idx = torch.randint(0, n_prompt, (64,))
a1 = []
a2 = []
y = []
for i in idx.tolist():
x1, x2 = sample_pair(i, rnd)
pref = 1 if x1 == hidden_preference(i) else 0
a1.append(x1)
a2.append(x2)
y.append(pref)
a1 = torch.tensor(a1, dtype=torch.long)
a2 = torch.tensor(a2, dtype=torch.long)
y = torch.tensor(y, dtype=torch.float32)
s1 = rm.score(idx, a1)
s2 = rm.score(idx, a2)
p = torch.sigmoid(s1 - s2)
loss = F.binary_cross_entropy(p, y)
opt_rm.zero_grad()
loss.backward()
opt_rm.step()
batch = torch.randint(0, n_prompt, (128,))
dist = pi.dist(batch)
a = dist.sample()
logp = dist.log_prob(a)
with torch.no_grad():
r = rm.score(batch, a)
logits_now = pi.logits(batch)
logits_ref = pi_ref.logits(batch)
kl = kl_categorical(logits_now, logits_ref)
reward = r - 0.1 * kl
adv = reward - reward.mean()
ratio = torch.exp(logp - logp.detach())
clip_eps = 0.2
obj1 = ratio * adv
obj2 = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * adv
loss = -(torch.min(obj1, obj2)).mean()
opt_ppo.zero_grad()
loss.backward()
opt_ppo.step()
with torch.no_grad():
p1 = torch.softmax(pi.logits(prompts), dim=-1)[:, 1].mean().item()
pref_rate = 0.0
for i in range(n_prompt):
a = int(torch.argmax(pi.logits(torch.tensor([i]))).item())
pref_rate += 1.0 if a == hidden_preference(i) else 0.0
pref_rate /= n_prompt
print("mean P(action=1):", round(p1, 3))
print("greedy preference rate:", round(pref_rate, 3))
你跑完会看到一个直观结果:
策略的“贪心选择”更接近隐藏偏好(因为 SFT 和后续更新都在往偏好方向推)。
这段玩具代码真正要你体会的只有三点:
- SFT 给了一个“像样的底盘”
- RM 把偏好变成可计算的分数
- PPO/RL 把分数变成策略的性格改变,同时用约束防止跑飞
六、2024–2026 的现实提醒:RLHF 成败常败在“细节”
看官别以为知道三步就能复刻工业 RLHF。
真实训练里,“细节”往往比“公式”更像刀口:
- 采样与过滤:生成什么样的候选,决定你收集到什么样的偏好信号
- KL 控制与奖励归一:决定你是在稳步变好,还是奖励黑客
- 批次组织与长度处理:决定训练稳定性与吞吐
2024 的复现实证工作强调了 RLHF + PPO 里许多关键实现细节,并展示它们对稳定性的影响。3
同年也出现把 RLHF 做得更异步、更高效的训练范式讨论,反映“规模上去以后,流水线必须工程化”。4
如果你把 RLHF 看作“炼丹”,那本回讲的是丹方结构;
而这些 2024 工作讲的是火候与药性。
七、小结:本回把整条流水线打通
收束成五句话:
- SFT 先学人话:把模型拉到可用起跑线
- RM 立尺子:把偏好变成分数
- RL(常用 PPO)调脾气:让策略更高分
- 参考策略与 KL 是绳:防止风格漂移与奖励黑客
- 工程细节决定成败:公式懂了不等于跑得稳
下一回(第18回)我们就讲一类“替代方案”:
DPO/IPO/KTO 等方法如何绕开 PPO 的一部分复杂性,把“偏好”更直接地写进模型更新里。
欲知后事如何,且听下回分解。
幻觉核查
- TL;DR 人类反馈摘要工作核对:arXiv:2009.01325 可核验作者、年份与“用人类偏好训练摘要”的主张。2
- InstructGPT(SFT→RM→PPO)核对:arXiv:2203.02155 可核验 RLHF 流水线结构与实验结论。1
- 2024 RLHF+PPO 细节复现核对:arXiv:2403.17031 可核验“实现细节枚举与复现”的定位。3
- 2024 异步 RLHF 核对:arXiv:2410.18252 可核验标题与“异步/off‑policy”方向表述。4
- 玩具代码中“人类偏好”由隐藏规则模拟,只用于演示流水线结构,不代表真实偏好分布。
逻辑审计
- 与第16回承接:第16回讲 PPO 怎么稳更新,本回解释 PPO 在 RLHF 里“为什么更新、更新什么”。
- 与第18回衔接:本回明确 RLHF 三步与约束角色,下回可逐段对比 DPO 系列替代的是哪一环。
- 难度控制:偏好建模只用 sigmoid 与“胜负概率”,不引入信息几何与复杂推导。
引用与溯源
Footnotes
-
Ouyang, L., et al. Training language models to follow instructions with human feedback arXiv:2203.02155 (2022) https://arxiv.org/abs/2203.02155 ↩ ↩2 ↩3 ↩4
-
Stiennon, N., et al. Learning to summarize with human feedback arXiv:2009.01325 (2020) https://arxiv.org/abs/2009.01325 ↩ ↩2 ↩3
-
Huang, S., et al. The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization arXiv:2403.17031 (2024-03) https://arxiv.org/abs/2403.17031 ↩ ↩2
-
Noukhovitch, M., et al. Asynchronous RLHF: Faster and More Efficient Off-Policy RL for Language Models arXiv:2410.18252 (2024-10) https://arxiv.org/abs/2410.18252 ↩ ↩2