InstructGPT,Chain of Thought,Llama 3.1论文笔记
1. InstructGPT: Training language models to follow instructions with human feedback
核心思想:通过人类反馈强化学习,解决预训练模型“不听话”或输出有害内容的问题。它让模型从单纯的“概率预测”转变为“意图对齐”。
技术方案:
SFT(有监督微调):在精选的人工指令数据集上微调 GPT-3。
RM(奖励模型训练):收集人对模型输出的排序,训练一个打分模型。
PPO(强化学习优化):利用 RM 的分数通过 PPO 算法不断迭代更新模型参数。
技术目标:提升模型的 3H 属性:Helpfulness(有用性)、Honesty(诚实性)、Harmlessness(无害性)。
实验结果:1.3B 参数的 InstructGPT 模型在人类偏好测试中击败了 175B 的 GPT-3。
瓶颈分析:
标注成本与偏差:人类评估员的品味决定了模型的上限,存在主观偏见。
奖励对齐税:过度对齐会导致模型在某些原始基准测试(如代码或数学)上的性能下降。
领域影响:它确立了现代大模型“预训练+对齐”的两阶段标准流程。
2. Chain of Thought Prompting Elicits Reasoning in Large Language Models
核心思想:大模型的逻辑推理能力是“涌现”出来的,可以通过在提示词中展示中间推理步骤来显著激活。
技术方案:
Few-shot CoT:在 Prompt 中给出几个包含“问题 -> 思考过程 -> 答案”的示例。
Zero-shot CoT:直接在提问后加上 "Let's think step by step"。
技术目标:攻克大模型在多步算术推理、符号处理和常识推理任务上的软肋。
实验结果:在 PaLM 540B 上,CoT 将数学推理任务(GSM8K)的准确率从 17.9% 提升到了 56.9%。
瓶颈分析:
规模依赖:CoT 效应在小模型(<10B)上几乎不生效,甚至会降低性能。
虚假推理:模型可能给出了正确的步骤,但答案是错的,或者结论正确但步骤是胡编的(幻觉问题)。
领域影响:开创了 提示工程的先河,并催生了 自一致性等后续增强技术。
推理层:复现 CoT 的 Self-Consistency (自一致性)
既然 CoT 的瓶颈是逻辑幻觉,OpenAI 和 Llama 常用 Self-Consistency 来解决:让模型生成多次推理路径,取出现次数最多的答案。
import openai
def solve_with_cot_consistency(question, n_samples=3):
responses = []
prompt = f"Question: {question}\nAnswer: Let's think step by step."
for _ in range(n_samples):
# 模拟调用模型生成多个推理分支
# 注意:设置 temperature > 0 以获得多样性
out = client.chat.completions.create(
model="llama-3.1-8b",
messages=[{"role": "user", "content": prompt}],
temperature=0.7
)
responses.append(out.choices[0].message.content)
# 逻辑处理:提取各回复中的最后答案(通常是最后一行或数字)
# 然后进行投票(Voting)
return responses
3. Llama 3.1: The Llama 3 Herd of Models
核心思想:证明了在大规模高质量数据(15T Tokens)和精细对齐策略下,标准 Transformer 架构依然具有强大的生命力,并能达到顶级闭源模型水平。
技术方案:
采用 GQA(分组查询注意力)和 RoPE(旋转位置编码),支持 128K 长度上下文。
后训练:结合了多轮 SFT、拒绝采样和 DPO(直接偏好优化)。
合成数据:利用模型自己生成高质量数据来辅助训练。
技术目标:构建一个性能对标 GPT-4o 的开源模型生态,提升多语言能力和工具调用能力。
实验结果:405B 版本在 150 多个基准测试中与 GPT-4o 互有胜负。
瓶颈分析:
推理成本:405B 的 Dense 架构对硬件要求极高(即便 4-bit 量化也需要多张 H100)。
知识截止:尽管训练数据庞大,但对于实时性极强的知识仍需依赖 RAG(检索增强)。
领域影响:它打破了“闭源模型绝对领先”的神话,通过 DPO 技术简化了 InstructGPT 中复杂的 RLHF 过程,成为目前开源界的“工业标准”。
架构层:复现 Llama 3.1 的核心——RoPE (旋转位置编码)
Llama 3.1 放弃了传统的绝对位置编码,改用 RoPE。这是模型能够扩展到 128K 上下文的关键。
import torch
import torch.nn as nn
def precompute_theta_pos_frequencies(dim: int, seq_len: int, theta: float = 10000.0):
# 计算旋转角度: theta_i = 10000^(-2(i-1)/d)
# dim 必须是偶数
theta_numerator = torch.arange(0, dim, 2).float()
theta_freqs = 1.0 / (theta ** (theta_numerator / dim))
# 生成位置索引 [0, 1, ..., seq_len-1]
m = torch.arange(seq_len)
# 计算外积得到所有位置的频率
freqs = torch.outer(m, theta_freqs) # Shape: (seq_len, dim/2)
# 转换为复数极坐标形式,方便旋转计算
freqs_complex = torch.polar(torch.ones_like(freqs), freqs)
return freqs_complex
def apply_rotary_emb(x: torch.Tensor, freqs_complex: torch.Tensor):
# x shape: (batch, seq_len, head, head_dim)
# 将输入转为复数形式
x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
# 广播频率并进行旋转计算 (复数乘法即旋转)
x_rotated = x_complex * freqs_complex.unsqueeze(0).unsqueeze(2)
# 转回实数
x_out = torch.view_as_real(x_rotated).flatten(3)
return x_out.type_as(x)
把上述函数用在例子中:
import torch
def precompute_theta_pos_frequencies(dim: int, seq_len: int, theta: float = 10000.0):
theta_numerator = torch.arange(0, dim, 2).float()
theta_freqs = 1.0 / (theta ** (theta_numerator / dim))
m = torch.arange(seq_len)
freqs = torch.outer(m, theta_freqs)
freqs_complex = torch.polar(torch.ones_like(freqs), freqs)
return freqs_complex
def apply_rotary_emb(x: torch.Tensor, freqs_complex: torch.Tensor):
x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
x_rotated = x_complex * freqs_complex.unsqueeze(0).unsqueeze(2)
x_out = torch.view_as_real(x_rotated).flatten(3)
return x_out.type_as(x)
# 模拟一个 Batch 的数据:batch=2, seq_len=128, heads=8, head_dim=64
batch, seq_len, heads, head_dim = 2, 128, 8, 64
# 定义 x (随机生成一些数据)
x = torch.randn(batch, seq_len, heads, head_dim)
# 计算频率矩阵
freqs_complex = precompute_theta_pos_frequencies(head_dim, seq_len)
# 3. 现在运行
print("输入 x 的形状:", x.shape)
output = apply_rotary_emb(x, freqs_complex)
print("RoPE 处理后输出的形状:", output.shape)
# 验证形状是否一致 (RoPE 不改变 Tensor 形状)
assert x.shape == output.shape
print("✅ 形状验证通过!RoPE 成功注入了位置信息。")

对齐层:极简版 DPO 数据构造逻辑
Llama 3.1 使用 DPO 代替了 InstructGPT 复杂的 PPO。DPO 的核心是准备 (Prompt, Chosen, Rejected) 三元组。
# 这是一个典型的 DPO 训练数据集样本格式
dpo_dataset_sample = {
"instruction": "请解释什么是堆栈溢出。",
"context": "在计算机科学课程中(如 CSAPP)...",
"chosen": "堆栈溢出是指程序向栈中写入的数据超过了分配的内存空间,导致覆盖了返回地址...", # 更好、更对齐的回答
"rejected": "就是电脑内存不够了,程序崩溃了。" # 敷衍或错误的回答
}
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐



所有评论(0)