前两篇我们解决了两件事:

  • 上篇:让模型学会“记住和忘记”(GRU/LSTM)。
  • 中篇:让模型“读得更深、更全”(深度与双向RNN)。

这一篇进入序列生成的核心场景:输入和输出都是序列,长度还不一定相同。机器翻译就是最典型的任务。

你可以把这篇当作一条完整实战链路:

  1. 数据怎么准备。
  2. 编码器-解码器怎么协作。
  3. Seq2Seq 训练和推理为什么不一样。
  4. 解码时为什么要束搜索,以及怎么选束宽。

一、机器翻译任务与数据准备

机器翻译属于 Seq2Seq(sequence-to-sequence)任务:

  • 输入:源语言序列(例如英文句子)。
  • 输出:目标语言序列(例如中文句子)。

它和分类任务最大的不同是:

  • 分类输出一个标签。
  • 翻译要按时间步逐词生成整个句子。

1. 数据预处理:先把语料变“可学”

真实语料常见问题:大小写混乱、空格异常、标点不统一、编码噪声。常见预处理步骤:

  • 文本规范化(大小写、全角半角、标点形式)。
  • 去除脏样本(空行、异常超长句)。
  • 统一切词规则(词或子词)。

这些步骤的价值是:减少无效噪声,让模型参数集中学习语言映射规律。

2. 词元化与词表:文本到数字的关键桥梁

神经网络只能处理数字,所以必须把文本映射成 token ID。

常见策略:

  • 英文:按词或子词切分(BPE/WordPiece 等)。
  • 中文:按字、词或子词切分。

然后构建词表并加入特殊 token:

  • <bos>:目标序列开始。
  • <eos>:目标序列结束。
  • <pad>:补齐对齐。
  • <unk>:词表外词元。

可以用这张表快速理解特殊 token 的作用:

Token 作用 若缺失会怎样
<bos> 告诉解码器“开始生成” 解码初始输入不明确
<eos> 告诉解码器“停止生成” 可能生成过长或不终止
<pad> 让 batch 内样本对齐 无法批量训练
<unk> 容纳未登录词 OOV 词直接报错或丢失

3. 变长序列批处理:掩码是必需品

句子长度不一致时,常见做法:

  1. 按最大长度截断。
  2. 不足长度用 <pad> 补齐。
  3. 构建有效长度掩码(mask)。

如果没有 mask,损失函数会把 <pad> 位置也当成“要预测的词”,训练目标会被污染。

下面是一个简化的掩码示例:

import torch

# 1 表示有效token,0表示<pad>
mask = torch.tensor([
    [1, 1, 1, 1, 0, 0],
    [1, 1, 1, 0, 0, 0],
], dtype=torch.float32)

# 假设每个位置的交叉熵损失
token_loss = torch.tensor([
    [0.8, 0.6, 0.9, 0.7, 1.2, 1.1],
    [0.5, 0.7, 0.6, 1.0, 1.3, 1.4],
])

# 只统计有效token
loss = (token_loss * mask).sum() / mask.sum()
print("masked loss:", float(loss))

二、编码器-解码器:先理解,再表达

编码器-解码器(Encoder-Decoder)是 Seq2Seq 的基础范式。

你可以把它想成口译过程:

  • 编码器:先完整听懂输入。
  • 解码器:再按目标语言逐步说出来。

1. 编码器做什么

输入源序列 x 1 , x 2 , . . . , x T x_1, x_2, ..., x_T x1,x2,...,xT 后,编码器输出语义表示(可是一组隐状态或最终状态)。

核心目标:把源句信息压缩为可供解码器使用的“上下文表示”。

2. 解码器做什么

在时间步 t t t,解码器依赖:

  • 上一步 token(训练时可能是真实词,推理时是模型预测词)。
  • 上一步解码器状态。
  • 编码器上下文。

对应概率分解:

P ( y 1 , … , y T ′ ) = ∏ t = 1 T ′ P ( y t ∣ y < t , context ) P(y_1, \dots, y_{T'}) = \prod_{t=1}^{T'} P(y_t \mid y_{<t}, \text{context}) P(y1,,yT)=t=1TP(yty<t,context)

这条式子可以看作 Seq2Seq 的核心定义。

3. 信息流示意图

源语言序列

编码器

上下文表示

解码器

目标语言序列

4. 一个现实问题:固定向量会不会丢信息

早期 Seq2Seq 常把整句压成单个上下文向量,句子很长时容易信息瓶颈。
这也是后来注意力机制(Attention)大放异彩的重要背景。


三、Seq2Seq训练与推理:同一模型,两套输入规则

很多新手在这里最容易混淆:训练阶段和推理阶段不完全一致。

1. 训练阶段:教师强制(Teacher Forcing)

训练时,解码器第 t t t 步通常喂入真实的 y t − 1 g o l d y_{t-1}^{gold} yt1gold,而不是模型上一时刻预测。

优点:

  • 收敛更快。
  • 梯度更稳定。
  • 早期不容易“越错越远”。

对应损失(带掩码)可写为:

L = − ∑ t m t log ⁡ P ( y t g o l d ∣ y < t , context ) \mathcal{L} = -\sum_t m_t \log P(y_t^{gold} \mid y_{<t}, \text{context}) L=tmtlogP(ytgoldy<t,context)

其中 m t m_t mt 是有效位置掩码。

2. 推理阶段:自回归生成

推理时没有真实目标序列,只能“自己喂自己”:

  1. 输入 <bos>
  2. 预测下一个词。
  3. 把预测词作为下一步输入。
  4. 遇到 <eos> 或达到最大长度时停止。

这就是自回归(autoregressive)解码。

3. 训练推理差异带来的问题:曝光偏差

因为训练时总看真实前词,推理时却看模型前词,会出现分布偏移(exposure bias)。

常见缓解思路:

  • Scheduled Sampling(逐步混合真实词与预测词)。
  • 更强的解码策略(如束搜索)。

Scheduled Sampling(调度采样)简介

Scheduled Sampling 是一种在训练时逐步把解码器输入从“总是使用真实前词”过渡到“按一定概率使用模型预测前词”的策略。核心思想是:在训练中引入与推理阶段类似的噪声,让模型在面对自己预测时也能保持鲁棒性,从而缓解曝光偏差(exposure bias)。

常见实现方式有两类:

  • 线性/指数衰减:随着训练步数增加,按线性或指数规律降低使用真实前词的概率 p_{teacher}。
  • 逆向概率或基于模型置信度的替换:根据模型对预测的置信度决定是否使用预测结果作为下一步输入。

示例伪代码(每个时间步以概率 p 使用真实前词):

for t in range(1, T):
    if random.random() < p_teacher:
        input_t = y_gold[:, t-1]
    else:
        input_t = model_pred[:, t-1].argmax(dim=-1)
    output_t, state = decoder.step(input_t, state, context)

    # 更新 p_teacher(伪代码)
    p_teacher = schedule(step)  # 例如 p_teacher = max(0, 1 - k * step)

注意事项:

  • Scheduled Sampling 能缓解暴露偏差,但过早或过强地使用模型预测可能导致训练不稳定;通常需要谨慎设计衰减曲线。
  • 对于有强制对齐或外部记忆机制的模型,Scheduled Sampling 的收益可能有限。

4. 代码示意:教师强制训练一步

import torch
import torch.nn.functional as F

# logits: (B, T, V)  模型输出
# tgt_out: (B, T)    目标词ID(右移后的监督序列)
# mask: (B, T)       有效位掩码,pad位为0

def masked_ce_loss(logits, tgt_out, mask):
    B, T, V = logits.shape

    # 展平后做逐token交叉熵
    loss_per_token = F.cross_entropy(
        logits.reshape(B * T, V),
        tgt_out.reshape(B * T),
        reduction='none'
    ).reshape(B, T)

    # 仅统计有效位置
    loss = (loss_per_token * mask).sum() / mask.sum().clamp_min(1.0)
    return loss

这段代码的重点不是“炫技”,而是让你看清 mask 在训练里的位置。

5. 评估:BLEU 常用但不是唯一真相

BLEU 衡量 n-gram 重合程度,优点是可自动化批量比较;但它也有边界:

  • 对同义改写不够友好。
  • 不直接衡量事实正确性。
  • 有时会出现“BLEU 还行但句子读起来生硬”。

实践中建议 BLEU + 人工抽样 + 任务指标联合看。


四、束搜索:解码质量的关键放大器

解码时每一步都要做选择,选择策略直接决定最终句子质量。

1. 贪心搜索 vs 束搜索

方法 每步策略 速度 全局质量
贪心搜索 只取当前最大概率词 容易局部最优
束搜索 保留前k条候选路径 通常更好

贪心快但短视;束搜索稍慢但更有全局视野。

2. 束搜索流程(新手版)

  1. 初始路径只有 <bos>
  2. 每条路径扩展出若干候选词。
  3. 计算累计分数,保留 top-k 路径。
  4. 重复直到达到 <eos> 或长度上限。
  5. 从完成路径中选分数最优结果。

流程图:

开始: BOS

扩展候选

保留top-k

是否结束

返回最优序列

3. 代码示意:简化束搜索伪代码

# 仅展示逻辑,不依赖具体模型实现

def beam_search_step(beams, k, step_fn):
    # beams: [(tokens, score, state), ...]
    all_candidates = []

    for tokens, score, state in beams:
        # step_fn 返回当前路径下一步候选: [(next_id, logp, next_state), ...]
        for next_id, logp, next_state in step_fn(tokens, state):
            new_tokens = tokens + [next_id]
            new_score = score + logp  # 累加log概率
            all_candidates.append((new_tokens, new_score, next_state))

    # 按分数从大到小排序,保留top-k
    all_candidates.sort(key=lambda x: x[1], reverse=True)
    return all_candidates[:k]

# 长度归一化建议:
# - 原始累计 log 概率偏好短序列(因为 logp 累加会随长度变小),
#   常用修正为 score / (length ** alpha) 或 score / length,其中 alpha 在 0~1 之间调节强度;
#   例如 Google NMT 常用 alpha=0.7 的幂次归一化。
# - 另外可在候选分数上加入重复惩罚或覆盖惩罚来减少冗余生成。

4. 束宽怎么选更合理

束宽k 典型现象
太小(1或2) 接近贪心,质量提升有限
适中(3到8) 常是效果和速度较平衡区间
太大(>10) 计算成本上升,收益可能递减

实际建议:从 k = 3 , 5 , 8 k=3,5,8 k=3,5,8 做小网格验证,不同任务最优值可能不同。

5. 束搜索常见细节

  • 长度偏置:长句 log 概率更容易累积更小,常用长度归一化修正。
  • 重复生成:可加入重复惩罚或 coverage 约束。
  • 结束策略:有的实现按“完成路径数达到阈值”提前停止。

五、统一小结

这一篇建议你记住七件事:

  1. 机器翻译是典型 Seq2Seq,输入输出都可变长。
  2. 数据处理里,词表和特殊 token 不是细节,而是训练可行性的前提。
  3. mask 对变长批处理至关重要,不做 mask 容易把训练目标带偏。
  4. 编码器-解码器的本质是“先理解,再生成”。
  5. 训练(教师强制)和推理(自回归)输入规则不同,理解这点非常关键。
  6. BLEU 有价值但有边界,评估要多指标结合。
  7. 束搜索是实践中提升生成质量的常用方法,束宽要按任务验证。

下一步你自然会进入注意力机制与 Transformer,它们在长依赖和并行计算上进一步升级了 Seq2Seq 框架。


本章阅读顺序

这一章我拆成三篇来写,保证更细、更适合新手循序渐进阅读:

注:本文为通俗改写与学习整理,思路参考《动手学深度学习》现代循环神经网络:https://zh.d2l.ai/chapter_recurrent-modern/index.html

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐