从零实现Transformer:第 3 部分 - 掩码多头注意力的掩码广播(Broadcasting of Masks in Masked Multi-Head Attention)

flyfish

以生成填充掩码 + 前瞻掩码的组合掩码 为例

1. 生成 Padding Mask(填充掩码)

屏蔽序列中的填充占位符(pad_id=0) 填充的 0 是无效字符,模型不应该关注、学习这些无意义的占位符

2. 生成 Look-ahead Mask(前瞻掩码)

屏蔽当前位置之后的所有未来 token,解码器是自回归生成(一步步生成文本),绝对不能提前看到未来的词

3. 合并掩码

| 运算把两个掩码合二为一:
只要是「填充位」或「未来位」,统一屏蔽(True)

用处

输出形状:[batch, 1, seq_len, seq_len]
这个掩码直接传入解码器的多头自注意力层
掩码为 True → 注意力分数置为负无穷,模型完全忽略该位置
掩码为 False → 正常计算注意力,模型可以关注该位置

import torch

def create_tgt_mask(tgt_ids, pad_id):
    """创建目标序列掩码(padding mask + look-ahead mask)"""
    # 1. 2维padding掩码 [batch, seq_len]
    padding_mask_2d = (tgt_ids == pad_id)
    
    # 2. 升维适配注意力维度 -> [batch, 1, 1, seq_len]
    tgt_padding_mask = padding_mask_2d.unsqueeze(1).unsqueeze(1)
    
    # 3. 生成序列长度
    tgt_seq_len = tgt_ids.shape[1]
    
    # 4. 构造上三角前瞻掩码 [seq_len, seq_len]
    # diagonal=1:主对角线上方为1,遮挡未来位置
    look_ahead_mask = torch.triu(
        torch.ones(tgt_seq_len, tgt_seq_len, device=tgt_ids.device), 
        diagonal=1
    ).bool()
    
    # 5. 升维支持批量广播 -> [1, 1, seq_len, seq_len]
    look_ahead_mask = look_ahead_mask.unsqueeze(0).unsqueeze(0)
    
    # 6. 合并掩码:任意一个为True就遮挡
    return tgt_padding_mask | look_ahead_mask

# 测试
if __name__ == "__main__":
    pad_id = 0
    # 2个batch,序列长度50为padding
    tgt_ids = torch.tensor([
        [1, 2, 3, 0, 0],
        [4, 5, 0, 0, 0]
    ])
    mask = create_tgt_mask(tgt_ids, pad_id)
    print("最终掩码形状:", mask.shape)   # torch.Size([2, 1, 5, 5])
    print("掩码内容:\n", mask)

输出

最终掩码形状: torch.Size([2, 1, 5, 5])
掩码内容:
 tensor([[[[False,  True,  True,  True,  True],
          [False, False,  True,  True,  True],
          [False, False, False,  True,  True],
          [False, False, False,  True,  True],
          [False, False, False,  True,  True]]],


        [[[False,  True,  True,  True,  True],
          [False, False,  True,  True,  True],
          [False, False,  True,  True,  True],
          [False, False,  True,  True,  True],
          [False, False,  True,  True,  True]]]])

广播 = PyTorch 自动把 形状不同但兼容 的张量,复制拉伸成相同形状,然后再运算

两个张量形状

return tgt_padding_mask | look_ahead_mask

两个输入形状:

  1. tgt_padding_mask[B, 1, 1, S] → 举例 [2, 1, 1, 4]
  2. look_ahead_mask[1, 1, S, S] → 举例 [1, 1, 4, 4]

广播目标:把两个张量都自动变成 [2, 1, 4, 4],再做 | 运算

广播规则

  1. 维度为 1 的位置,可以自动复制扩展成任意大小
  2. 扩展后,两个张量形状完全一致,就能运算

例子1:最简单的2维广播
模拟:小张量自动拉伸

import torch
# 形状 [1,4] → 1行4列
a = torch.tensor([[True, False, True, False]])  
# 形状 [4,4] → 4行4列
b = torch.ones(4,4).bool()  

# 广播运算:a自动复制4行,变成[4,4],再和b运算
c = a | b
print("a形状:", a.shape)
print("b形状:", b.shape)
print("广播后运算结果形状:", c.shape)  # 输出 [4,4]

[1,4] 自动扩成 [4,4]

例子2:3维广播(过渡)

# [2,1,4]
a = torch.rand(2,1,4).bool()
# [1,4,4]
b = torch.rand(1,4,4).bool()

# 自动广播成 [2,4,4]
c = a | b
print(c.shape)  # [2,4,4]

例子3:模拟代码的4维广播

import torch
# 模拟两个掩码
B, S = 2, 4
# 1. padding掩码 [2,1,1,4]
tgt_pad_mask = torch.rand(B, 1, 1, S).bool()
# 2. 前瞻掩码 [1,1,4,4]
look_ahead_mask = torch.rand(1, 1, S, S).bool()

# 广播运算!
final_mask = tgt_pad_mask | look_ahead_mask

# 打印形状
print("padding掩码形状:", tgt_pad_mask.shape)    # [2,1,1,4]
print("前瞻掩码形状:", look_ahead_mask.shape)    # [1,1,4,4]
print("广播后最终形状:", final_mask.shape)      # [2,1,4,4]

代码里用到的

输入参数

# 2个句子,每个句子最长5个词
tgt_ids = torch.tensor([
    [1, 2, 3, 0, 0],  # 第1个样本:有效词3个,后2个是填充0
    [4, 5, 0, 0, 0]   # 第2个样本:有效词2个,后3个是填充0
])
pad_id = 0  # 0代表填充位

批次大小 B = 2
序列长度 S = 5
标准掩码维度:[batch, num_heads, seq_q, seq_k]

最终维度是 [2, 1, 5, 5]

[2, 1, 5, 5]
 = [批次B, 头数H, 查询序列长Q, 键序列长K]
  1. 2:一次性处理 2 个句子(batch=2)
  2. 1:代码里没做多头,默认 1 个注意力头
  3. 5:Query 向量数量 = 目标序列长度 = 5
  4. 5:Key 向量数量 = 目标序列长度 = 5

代码里的广播

  1. tgt_padding_mask 形状:[2, 1, 1, 5]
  2. look_ahead_mask 形状:[1, 1, 5, 5]
  3. PyTorch 自动广播 把两个张量都拉伸为 [2, 1, 5, 5],再做 | 运算

掩码内容

最终掩码 = 前瞻掩码填充掩码
True = 遮挡(不让看)
False = 允许看

1. 前瞻掩码(固定不变,所有样本共用)

torch.triu(..., diagonal=1) 生成固定上三角矩阵

# 5x5 前瞻掩码(对角线以上全是True,遮挡未来词)
[
[F, T, T, T, T],  # 第1个词:只能看自己,不能看后面4个
[F, F, T, T, T],  # 第2个词:能看自己+前1个,不能看后面3个
[F, F, F, T, T],  # 第3个词:能看自己+前2个,不能看后面2个
[F, F, F, F, T],  # 第4个词:能看自己+前3个,不能看后面1个
[F, F, F, F, F]   # 第5个词:能看所有前面的词
]

2. 填充掩码(每个样本不一样)

样本1 [1,2,3,0,0]第4、5位是填充 → 掩码 [F,F,F,T,T]
样本2 [4,5,0,0,0]第3、4、5位是填充 → 掩码 [F,F,T,T,T]

最终合并结果

样本1 输出(第一块 5x5)

[[False,  True,  True,  True,  True],
 [False, False,  True,  True,  True],
 [False, False, False,  True,  True],
 [False, False, False,  True,  True],  # 第4位是填充,永久遮挡
 [False, False, False,  True,  True]]  # 第5位是填充,永久遮挡

前3行:只受前瞻掩码影响
后2行:前瞻掩码 + 填充掩码 双重遮挡

样本2 输出(第二块 5x5)

[[False,  True,  True,  True,  True],
 [False, False,  True,  True,  True],
 [False, False,  True,  True,  True],  # 第3位是填充,永久遮挡
 [False, False,  True,  True,  True],  # 第4位是填充,永久遮挡
 [False, False,  True,  True,  True]]  # 第5位是填充,永久遮挡

前2行:只受前瞻掩码影响
后3行:前瞻掩码 + 填充掩码 双重遮挡

简单的流程就是

  1. 维度 [2,1,5,5]
    [2个句子, 1个注意力头, 每个句子5个Query, 每个句子5个Key]
  2. 掩码内容
    上三角的 True = 遮挡未来词(前瞻掩码)
    后半列的 True = 遮挡填充0(填充掩码)
  3. 两者合并,就是看到的输出

不用广播的写法

import torch

def create_tgt_mask_no_broadcast(tgt_ids, pad_id):
    """创建目标序列掩码(无广播版,手动扩展维度)"""
    B, S = tgt_ids.shape  # 直接获取批次B=2,序列长S=5
    # 1. 2维padding掩码 [batch, seq_len][2,5]
    padding_mask_2d = (tgt_ids == pad_id)
    
    # 2. 升维 → [B, 1, 1, S][2,1,1,5]
    tgt_padding_mask = padding_mask_2d.unsqueeze(1).unsqueeze(1)
    
    # ============== 替代广播 ==============
    # 把第3维(seq_q)从 1 复制成 S → 形状变成 [B,1,S,S] = [2,1,5,5]
    tgt_padding_mask = tgt_padding_mask.repeat(1, 1, S, 1)

    # 3. 构造上三角前瞻掩码 [S, S][5,5]
    look_ahead_mask = torch.triu(
        torch.ones(S, S, device=tgt_ids.device), 
        diagonal=1
    ).bool()
    
    # 4. 升维 → [1, 1, S, S][1,1,5,5]
    look_ahead_mask = look_ahead_mask.unsqueeze(0).unsqueeze(0)
    
    # ============== 替代广播 ==============
    # 把第0维(batch)从 1 复制成 B → 形状变成 [B,1,S,S] = [2,1,5,5]
    look_ahead_mask = look_ahead_mask.repeat(B, 1, 1, 1)

    # 6. 两个掩码形状完全一致,直接运算(无任何广播)
    return tgt_padding_mask | look_ahead_mask

# 测试
if __name__ == "__main__":
    pad_id = 0
    tgt_ids = torch.tensor([
        [1, 2, 3, 0, 0],
        [4, 5, 0, 0, 0]
    ])
    
    # 运行无广播版本
    mask = create_tgt_mask_no_broadcast(tgt_ids, pad_id)
    print("最终掩码形状:", mask.shape)   # 依旧是 torch.Size([2, 1, 5, 5])
    print("掩码内容:\n", mask)

输出

最终掩码形状: torch.Size([2, 1, 5, 5])
掩码内容:
 tensor([[[[False,  True,  True,  True,  True],
          [False, False,  True,  True,  True],
          [False, False, False,  True,  True],
          [False, False, False,  True,  True],
          [False, False, False,  True,  True]]],


        [[[False,  True,  True,  True,  True],
          [False, False,  True,  True,  True],
          [False, False,  True,  True,  True],
          [False, False,  True,  True,  True],
          [False, False,  True,  True,  True]]]])
Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐