现代循环神经网络(三):编码器-解码器、Seq2Seq与束搜索
前两篇我们解决了两件事:
- 上篇:让模型学会“记住和忘记”(GRU/LSTM)。
- 中篇:让模型“读得更深、更全”(深度与双向RNN)。
这一篇进入序列生成的核心场景:输入和输出都是序列,长度还不一定相同。机器翻译就是最典型的任务。
你可以把这篇当作一条完整实战链路:
- 数据怎么准备。
- 编码器-解码器怎么协作。
- Seq2Seq 训练和推理为什么不一样。
- 解码时为什么要束搜索,以及怎么选束宽。
一、机器翻译任务与数据准备
机器翻译属于 Seq2Seq(sequence-to-sequence)任务:
- 输入:源语言序列(例如英文句子)。
- 输出:目标语言序列(例如中文句子)。
它和分类任务最大的不同是:
- 分类输出一个标签。
- 翻译要按时间步逐词生成整个句子。
1. 数据预处理:先把语料变“可学”
真实语料常见问题:大小写混乱、空格异常、标点不统一、编码噪声。常见预处理步骤:
- 文本规范化(大小写、全角半角、标点形式)。
- 去除脏样本(空行、异常超长句)。
- 统一切词规则(词或子词)。
这些步骤的价值是:减少无效噪声,让模型参数集中学习语言映射规律。
2. 词元化与词表:文本到数字的关键桥梁
神经网络只能处理数字,所以必须把文本映射成 token ID。
常见策略:
- 英文:按词或子词切分(BPE/WordPiece 等)。
- 中文:按字、词或子词切分。
然后构建词表并加入特殊 token:
<bos>:目标序列开始。<eos>:目标序列结束。<pad>:补齐对齐。<unk>:词表外词元。
可以用这张表快速理解特殊 token 的作用:
| Token | 作用 | 若缺失会怎样 |
|---|---|---|
<bos> |
告诉解码器“开始生成” | 解码初始输入不明确 |
<eos> |
告诉解码器“停止生成” | 可能生成过长或不终止 |
<pad> |
让 batch 内样本对齐 | 无法批量训练 |
<unk> |
容纳未登录词 | OOV 词直接报错或丢失 |
3. 变长序列批处理:掩码是必需品
句子长度不一致时,常见做法:
- 按最大长度截断。
- 不足长度用
<pad>补齐。 - 构建有效长度掩码(mask)。
如果没有 mask,损失函数会把 <pad> 位置也当成“要预测的词”,训练目标会被污染。
下面是一个简化的掩码示例:
import torch
# 1 表示有效token,0表示<pad>
mask = torch.tensor([
[1, 1, 1, 1, 0, 0],
[1, 1, 1, 0, 0, 0],
], dtype=torch.float32)
# 假设每个位置的交叉熵损失
token_loss = torch.tensor([
[0.8, 0.6, 0.9, 0.7, 1.2, 1.1],
[0.5, 0.7, 0.6, 1.0, 1.3, 1.4],
])
# 只统计有效token
loss = (token_loss * mask).sum() / mask.sum()
print("masked loss:", float(loss))
二、编码器-解码器:先理解,再表达
编码器-解码器(Encoder-Decoder)是 Seq2Seq 的基础范式。
你可以把它想成口译过程:
- 编码器:先完整听懂输入。
- 解码器:再按目标语言逐步说出来。
1. 编码器做什么
输入源序列 x 1 , x 2 , . . . , x T x_1, x_2, ..., x_T x1,x2,...,xT 后,编码器输出语义表示(可是一组隐状态或最终状态)。
核心目标:把源句信息压缩为可供解码器使用的“上下文表示”。
2. 解码器做什么
在时间步 t t t,解码器依赖:
- 上一步 token(训练时可能是真实词,推理时是模型预测词)。
- 上一步解码器状态。
- 编码器上下文。
对应概率分解:
P ( y 1 , … , y T ′ ) = ∏ t = 1 T ′ P ( y t ∣ y < t , context ) P(y_1, \dots, y_{T'}) = \prod_{t=1}^{T'} P(y_t \mid y_{<t}, \text{context}) P(y1,…,yT′)=t=1∏T′P(yt∣y<t,context)
这条式子可以看作 Seq2Seq 的核心定义。
3. 信息流示意图
4. 一个现实问题:固定向量会不会丢信息
早期 Seq2Seq 常把整句压成单个上下文向量,句子很长时容易信息瓶颈。
这也是后来注意力机制(Attention)大放异彩的重要背景。
三、Seq2Seq训练与推理:同一模型,两套输入规则
很多新手在这里最容易混淆:训练阶段和推理阶段不完全一致。
1. 训练阶段:教师强制(Teacher Forcing)
训练时,解码器第 t t t 步通常喂入真实的 y t − 1 g o l d y_{t-1}^{gold} yt−1gold,而不是模型上一时刻预测。
优点:
- 收敛更快。
- 梯度更稳定。
- 早期不容易“越错越远”。
对应损失(带掩码)可写为:
L = − ∑ t m t log P ( y t g o l d ∣ y < t , context ) \mathcal{L} = -\sum_t m_t \log P(y_t^{gold} \mid y_{<t}, \text{context}) L=−t∑mtlogP(ytgold∣y<t,context)
其中 m t m_t mt 是有效位置掩码。
2. 推理阶段:自回归生成
推理时没有真实目标序列,只能“自己喂自己”:
- 输入
<bos>。 - 预测下一个词。
- 把预测词作为下一步输入。
- 遇到
<eos>或达到最大长度时停止。
这就是自回归(autoregressive)解码。
3. 训练推理差异带来的问题:曝光偏差
因为训练时总看真实前词,推理时却看模型前词,会出现分布偏移(exposure bias)。
常见缓解思路:
- Scheduled Sampling(逐步混合真实词与预测词)。
- 更强的解码策略(如束搜索)。
Scheduled Sampling(调度采样)简介
Scheduled Sampling 是一种在训练时逐步把解码器输入从“总是使用真实前词”过渡到“按一定概率使用模型预测前词”的策略。核心思想是:在训练中引入与推理阶段类似的噪声,让模型在面对自己预测时也能保持鲁棒性,从而缓解曝光偏差(exposure bias)。
常见实现方式有两类:
- 线性/指数衰减:随着训练步数增加,按线性或指数规律降低使用真实前词的概率 p_{teacher}。
- 逆向概率或基于模型置信度的替换:根据模型对预测的置信度决定是否使用预测结果作为下一步输入。
示例伪代码(每个时间步以概率 p 使用真实前词):
for t in range(1, T):
if random.random() < p_teacher:
input_t = y_gold[:, t-1]
else:
input_t = model_pred[:, t-1].argmax(dim=-1)
output_t, state = decoder.step(input_t, state, context)
# 更新 p_teacher(伪代码)
p_teacher = schedule(step) # 例如 p_teacher = max(0, 1 - k * step)
注意事项:
- Scheduled Sampling 能缓解暴露偏差,但过早或过强地使用模型预测可能导致训练不稳定;通常需要谨慎设计衰减曲线。
- 对于有强制对齐或外部记忆机制的模型,Scheduled Sampling 的收益可能有限。
4. 代码示意:教师强制训练一步
import torch
import torch.nn.functional as F
# logits: (B, T, V) 模型输出
# tgt_out: (B, T) 目标词ID(右移后的监督序列)
# mask: (B, T) 有效位掩码,pad位为0
def masked_ce_loss(logits, tgt_out, mask):
B, T, V = logits.shape
# 展平后做逐token交叉熵
loss_per_token = F.cross_entropy(
logits.reshape(B * T, V),
tgt_out.reshape(B * T),
reduction='none'
).reshape(B, T)
# 仅统计有效位置
loss = (loss_per_token * mask).sum() / mask.sum().clamp_min(1.0)
return loss
这段代码的重点不是“炫技”,而是让你看清 mask 在训练里的位置。
5. 评估:BLEU 常用但不是唯一真相
BLEU 衡量 n-gram 重合程度,优点是可自动化批量比较;但它也有边界:
- 对同义改写不够友好。
- 不直接衡量事实正确性。
- 有时会出现“BLEU 还行但句子读起来生硬”。
实践中建议 BLEU + 人工抽样 + 任务指标联合看。
四、束搜索:解码质量的关键放大器
解码时每一步都要做选择,选择策略直接决定最终句子质量。
1. 贪心搜索 vs 束搜索
| 方法 | 每步策略 | 速度 | 全局质量 |
|---|---|---|---|
| 贪心搜索 | 只取当前最大概率词 | 快 | 容易局部最优 |
| 束搜索 | 保留前k条候选路径 | 中 | 通常更好 |
贪心快但短视;束搜索稍慢但更有全局视野。
2. 束搜索流程(新手版)
- 初始路径只有
<bos>。 - 每条路径扩展出若干候选词。
- 计算累计分数,保留 top-k 路径。
- 重复直到达到
<eos>或长度上限。 - 从完成路径中选分数最优结果。
流程图:
3. 代码示意:简化束搜索伪代码
# 仅展示逻辑,不依赖具体模型实现
def beam_search_step(beams, k, step_fn):
# beams: [(tokens, score, state), ...]
all_candidates = []
for tokens, score, state in beams:
# step_fn 返回当前路径下一步候选: [(next_id, logp, next_state), ...]
for next_id, logp, next_state in step_fn(tokens, state):
new_tokens = tokens + [next_id]
new_score = score + logp # 累加log概率
all_candidates.append((new_tokens, new_score, next_state))
# 按分数从大到小排序,保留top-k
all_candidates.sort(key=lambda x: x[1], reverse=True)
return all_candidates[:k]
# 长度归一化建议:
# - 原始累计 log 概率偏好短序列(因为 logp 累加会随长度变小),
# 常用修正为 score / (length ** alpha) 或 score / length,其中 alpha 在 0~1 之间调节强度;
# 例如 Google NMT 常用 alpha=0.7 的幂次归一化。
# - 另外可在候选分数上加入重复惩罚或覆盖惩罚来减少冗余生成。
4. 束宽怎么选更合理
| 束宽k | 典型现象 |
|---|---|
| 太小(1或2) | 接近贪心,质量提升有限 |
| 适中(3到8) | 常是效果和速度较平衡区间 |
| 太大(>10) | 计算成本上升,收益可能递减 |
实际建议:从 k = 3 , 5 , 8 k=3,5,8 k=3,5,8 做小网格验证,不同任务最优值可能不同。
5. 束搜索常见细节
- 长度偏置:长句 log 概率更容易累积更小,常用长度归一化修正。
- 重复生成:可加入重复惩罚或 coverage 约束。
- 结束策略:有的实现按“完成路径数达到阈值”提前停止。
五、统一小结
这一篇建议你记住七件事:
- 机器翻译是典型 Seq2Seq,输入输出都可变长。
- 数据处理里,词表和特殊 token 不是细节,而是训练可行性的前提。
- mask 对变长批处理至关重要,不做 mask 容易把训练目标带偏。
- 编码器-解码器的本质是“先理解,再生成”。
- 训练(教师强制)和推理(自回归)输入规则不同,理解这点非常关键。
- BLEU 有价值但有边界,评估要多指标结合。
- 束搜索是实践中提升生成质量的常用方法,束宽要按任务验证。
下一步你自然会进入注意力机制与 Transformer,它们在长依赖和并行计算上进一步升级了 Seq2Seq 框架。
本章阅读顺序
这一章我拆成三篇来写,保证更细、更适合新手循序渐进阅读:
- 上篇:现代循环神经网络:从GRU到Seq2Seq,学会更长更复杂的序列.md
- 中篇:现代循环神经网络(二):深度RNN与双向RNN,读得更深更全.md
- 下篇(本篇):编码器-解码器、Seq2Seq 与束搜索
注:本文为通俗改写与学习整理,思路参考《动手学深度学习》现代循环神经网络:https://zh.d2l.ai/chapter_recurrent-modern/index.html
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐



所有评论(0)