Transformer 掩码张量全解析:从核心作用到代码实现
掩码(Mask)是 Transformer 能正确工作的核心机制之一。它用来控制哪些位置可以被注意力看到。Transformer 中有两种典型掩码:填充掩码(Padding Mask)和未来信息掩码(Subsequent Mask)。Transformer 中使用填充掩码和未来信息掩码的位置(注意力层)如下表所示。
| 注意力层 | 填充掩码 | 未来信息掩码 | 核心原因 |
|---|---|---|---|
| 编码器自注意力 | ✅ 必须用 | ❌ 不用 | 源序列有 padding,无未来信息 |
| 解码器自注意力 | ✅ 必须用 | ✅ 必须用(专属) | 目标序列有 padding + 防偷看未来 |
| 编码器解码器交叉注意力 | ✅ 必须用 | ❌ 不用 | 关注编码器输出(源序列 padding) |
1 掩码张量的本质
掩码张量是一个布尔型张量(最终会转为 0/1 矩阵):
- 标记为
True:当前位置可见,自注意力可以正常关注 - 标记为
False:当前位置屏蔽,自注意力会完全忽略该位置
2 两种掩码的作用
2.1 填充掩码(Padding Mask):处理输入的不定长序列
在自然能语言处理中,句子长度参差不齐,batch 训练时需要用0填充短序列(比如把 [“我”,“爱”,“AI”] 和 [“学习”] 补成等长:[“我”,“爱”,“AI”]、[“学习”,“0”,“0”]),但这些填充的0是无意义噪声,自注意力如果关注这些位置,会学到无效信息。
填充掩码能识别序列中的填充位置,将其标记为False,让自注意力完全忽略 padding 噪声。
2.2 未来信息掩码(Subsequent Mask):防止解码器“偷看未来”
未来信息掩码的引入是为了解决一个问题——并行训练与自回归推理的矛盾。在 Transformer 的解码器中,生成过程是自回归的。也就是说,生成第 ttt 个词时,只能依赖于第 111 到 t−1t-1t−1 个词,绝对不能看到第 t+1t+1t+1 个词。
- 推理阶段:这很好办。模型生成第一个词,再根据第一个词生成第二个,依次类推。天然满足“看不见未来”。
- 训练阶段:为了加速,Transformer 使用了Teacher Forcing,将整个目标序列一次性输入模型,利用并行计算加速。
这就带来一个问题:如果不加限制,当模型计算第 ttt 个位置的注意力时,它可以通过注意力机制直接“看到”第 t+1,t+2...t+1, t+2...t+1,t+2... 个词的信息。 这就是所谓的“信息泄露”。如果模型能提前看到答案,它就会“偷懒”,直接复制答案,导致模型无法真正学习预测能力。
未来信息掩码强制模型生成第 ttt 个词时,只能关注第 111 到 t−1t-1t−1 个历史位置,彻底屏蔽 ttt 之后的所有未来位置。
在计算 Self-Attention 的缩放点积注意力时,公式如下:
Attention(Q,K,V)=softmax(QKTdk+M)V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}} + M\right) V Attention(Q,K,V)=softmax(dkQKT+M)V
这里的 MMM 就是掩码张量,形状和QKTQK^TQKT 计算出的相关性得分矩阵相同。
- 计算得分:QKTQK^TQKT 计算出的相关性得分矩阵。
- 应用掩码:代码中
False(对应矩阵中的 0) 的位置,通常会被加上一个巨大的负数(如 −1e9-1e9−1e9)。 - Softmax归一化:在 Softmax 阶段,e−∞≈0e^{-\infty} \approx 0e−∞≈0。这就意味着被掩码掉的位置,其权重变为 0。
结果:在计算第 ttt 行的注意力时,第 ttt 列之后的所有列权重都为 0,模型无法“看见”未来的词,只能关注当前及之前的词。
未来信息掩码的代码实现:
import torch
def subsequent_mask(size):
# 生成向后遮掩的掩码张量(未来信息掩码)
# 参数size:掩码最后两维的大小,形成方阵(对应序列长度)
attn_shape = (1, size, size) # 1是batch扩充维度,方便后续广播匹配
# 1. 生成全1张量 → 2. 取上三角矩阵(对角线以上为1,对角线及以下为0)→ 3. 转uint8节约内存
subsequent_mask = torch.triu(torch.ones(attn_shape), diagonal=1).type(
torch.uint8
)
# 反转布尔值:上三角(未来位置)→ False(屏蔽),非上三角(历史/当前)→ True(可见)
return subsequent_mask == 0
为什么这样实现?
(1)torch.triu 与 diagonal=1
torch.triu(torch.ones(attn_shape), diagonal=1)
triu 是 “Upper Triangle” 的缩写。diagonal=1 意味着截取的主对角线向右偏移 1 个单位。
也就是说,主对角线上的元素变成了 0。
这让掩码逻辑非常严谨:预测当前词,连当前词本身都不能看,只能看之前的词。(当然,具体的实现逻辑取决于位置编码的叠加方式,通常第 ttt 个位置的预测主要基于前文和当前位置的输入向量)。
(2)维度扩充 (1, size, size)
attn_shape = (1, size, size)
为什么不是 (size, size)?
这是为了适配 PyTorch 的广播机制。在 Transformer 中,输入通常是批量的,形状为 (batch_size, seq_len, d_model)。
Attention 得分矩阵形状为 (batch_size, num_heads, seq_len, seq_len)。
掩码张量定义为 (1, size, size)(实际上是 (1, 1, size, size) 的变体),可以自动广播到整个 Batch 和所有的注意力头上,无需手动复制,极大地节省了内存。
(3)数据类型 .type(torch.uint8)
.type(torch.uint8)
这是一个性能优化的小技巧。在 GPU 计算中,布尔运算和整型运算很快。使用 uint8(无符号8位整型)比 float32 或 bool(在某些旧版本PyTorch中)更节省显存,且计算效率高。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐




所有评论(0)